From a1d0c3bcc89ca0473c5e52308ec152de5d2ae0ed Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 4 Jul 2025 13:52:26 +0200 Subject: [PATCH 01/11] Allow `DEFAULT_QUEUE` to be configurable from airflow settings in Task SDK (#52786) Co-authored-by: David Blain (cherry picked from commit c65dc8e09caf6090e8a0e680c2a2cc8861889ca3) --- .../src/airflow/sdk/definitions/_internal/abstractoperator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index ec2fefa0a08b4..95c8d9d8287cb 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -30,6 +30,7 @@ import methodtools +from airflow.configuration import conf from airflow.sdk.definitions._internal.mixins import DependencyMixin from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions._internal.templater import Templater @@ -61,7 +62,7 @@ MINIMUM_PRIORITY_WEIGHT: int = -2147483648 MAXIMUM_PRIORITY_WEIGHT: int = 2147483647 DEFAULT_EXECUTOR: str | None = None -DEFAULT_QUEUE: str = "default" +DEFAULT_QUEUE: str = conf.get("operators", "default_queue", "default") DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False DEFAULT_RETRIES: int = 0 From 743e6ac6533ce8b0ccac0a79f03b43b6d7d7d3a1 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 4 Jul 2025 23:19:13 +0530 Subject: [PATCH 02/11] Fix task configuration defaults for AbstractOperator (#52871) Some defaults weren't being taken from configuration -- this is now fixed. (cherry picked from commit 04d2d3b8b8ce831d8725ca661a61f1e310d27eff) --- airflow-core/src/airflow/models/__init__.py | 21 ++++++++++++ .../src/airflow/models/abstractoperator.py | 34 ------------------- .../serialization/test_dag_serialization.py | 2 ++ .../edge3/executors/edge_executor.py | 6 +++- .../definitions/_internal/abstractoperator.py | 20 +++++++---- .../sdk/definitions/decorators/__init__.pyi | 4 +-- .../airflow/sdk/definitions/mappedoperator.py | 10 ++---- 7 files changed, 46 insertions(+), 51 deletions(-) delete mode 100644 airflow-core/src/airflow/models/abstractoperator.py diff --git a/airflow-core/src/airflow/models/__init__.py b/airflow-core/src/airflow/models/__init__.py index 9274ae7a79f3f..b5a8e61e49ae7 100644 --- a/airflow-core/src/airflow/models/__init__.py +++ b/airflow-core/src/airflow/models/__init__.py @@ -19,6 +19,8 @@ from __future__ import annotations +from airflow.utils.deprecation_tools import add_deprecated_classes + # Do not add new models to this -- this is for compat only __all__ = [ "DAG", @@ -141,3 +143,22 @@ def __getattr__(name): from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions.param import Param from airflow.sdk.execution_time.xcom import XCom + +__deprecated_classes = { + "abstractoperator": { + "AbstractOperator": "airflow.sdk.definitions._internal.abstractoperator.AbstractOperator", + "NotMapped": "airflow.sdk.definitions._internal.abstractoperator.NotMapped", + "TaskStateChangeCallback": "airflow.sdk.definitions._internal.abstractoperator.TaskStateChangeCallback", + "DEFAULT_OWNER": "airflow.sdk.definitions._internal.abstractoperator.DEFAULT_OWNER", + "DEFAULT_QUEUE": "airflow.sdk.definitions._internal.abstractoperator.DEFAULT_QUEUE", + "DEFAULT_TASK_EXECUTION_TIMEOUT": 
"airflow.sdk.definitions._internal.abstractoperator.DEFAULT_TASK_EXECUTION_TIMEOUT", + }, + "param": { + "Param": "airflow.sdk.definitions.param.Param", + "ParamsDict": "airflow.sdk.definitions.param.ParamsDict", + }, + "baseoperatorlink": { + "BaseOperatorLink": "airflow.sdk.bases.operatorlink.BaseOperatorLink", + }, +} +add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow-core/src/airflow/models/abstractoperator.py b/airflow-core/src/airflow/models/abstractoperator.py deleted file mode 100644 index e5b5f7dc81f45..0000000000000 --- a/airflow-core/src/airflow/models/abstractoperator.py +++ /dev/null @@ -1,34 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import datetime - -from airflow.configuration import conf -from airflow.sdk.definitions._internal.abstractoperator import ( - AbstractOperator as AbstractOperator, - NotMapped as NotMapped, # Re-export this for compat - TaskStateChangeCallback as TaskStateChangeCallback, -) - -DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner") -DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue") - -DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta( - "core", "default_task_execution_timeout" -) diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index a7508b3ea520d..f4a719496f27d 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -187,6 +187,7 @@ "bash_command": "echo {{ task.task_id }}", "task_type": "BashOperator", "_task_module": "airflow.providers.standard.operators.bash", + "owner": "airflow", "pool": "default_pool", "is_setup": False, "is_teardown": False, @@ -3163,6 +3164,7 @@ def test_handle_v1_serdag(): "_task_type": "BashOperator", # Slightly difference from v2-10-stable here, we manually changed this path "_task_module": "airflow.providers.standard.operators.bash", + "owner": "airflow", "pool": "default_pool", "is_setup": False, "is_teardown": False, diff --git a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py index 55d5414cd8965..20bf71623d179 100644 --- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py +++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py @@ -30,7 +30,11 @@ from airflow.cli.cli_config import GroupCommand from airflow.configuration import conf from airflow.executors.base_executor import BaseExecutor -from airflow.models.abstractoperator import DEFAULT_QUEUE + +try: + from airflow.models.abstractoperator import 
DEFAULT_QUEUE +except (ImportError, AttributeError): + from airflow.sdk.definitions._internal.abstractoperator import DEFAULT_QUEUE from airflow.models.taskinstance import TaskInstance, TaskInstanceState from airflow.providers.edge3.cli.edge_command import EDGE_COMMANDS from airflow.providers.edge3.models.edge_job import EdgeJobModel diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index 95c8d9d8287cb..8934cd0e4532c 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -51,7 +51,7 @@ TaskStateChangeCallback = Callable[[Context], None] -DEFAULT_OWNER: str = "airflow" +DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner") DEFAULT_POOL_SLOTS: int = 1 DEFAULT_POOL_NAME = "default_pool" DEFAULT_PRIORITY_WEIGHT: int = 1 @@ -62,17 +62,23 @@ MINIMUM_PRIORITY_WEIGHT: int = -2147483648 MAXIMUM_PRIORITY_WEIGHT: int = 2147483647 DEFAULT_EXECUTOR: str | None = None -DEFAULT_QUEUE: str = conf.get("operators", "default_queue", "default") +DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue") DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False -DEFAULT_RETRIES: int = 0 -DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(seconds=300) -MAX_RETRY_DELAY: int = 24 * 60 * 60 +DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0) +DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta( + seconds=conf.getint("core", "default_task_retry_delay", fallback=300) +) +MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60) # TODO: Task-SDK -- these defaults should be overridable from the Airflow config DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS -DEFAULT_WEIGHT_RULE: WeightRule = WeightRule.DOWNSTREAM -DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = None +DEFAULT_WEIGHT_RULE: WeightRule = WeightRule( + conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM) +) +DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta( + "core", "default_task_execution_timeout" +) log = logging.getLogger(__name__) diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi index 30e921f2f4881..038f94a0bd8f5 100644 --- a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi +++ b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi @@ -496,7 +496,7 @@ class TaskDecoratorCollection: """ # [END decorator_signature] @overload - def kubernetes( + def kubernetes( # type: ignore[misc] self, *, multiple_outputs: bool | None = None, @@ -670,7 +670,7 @@ class TaskDecoratorCollection: @overload def kubernetes(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... @overload - def kubernetes_cmd( + def kubernetes_cmd( # type: ignore[misc] self, *, args_only: bool = False, # Added by _KubernetesCmdDecoratedOperator. 
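
A note on the compat shim in this patch: the deleted ``airflow/models/abstractoperator.py`` is replaced by the ``add_deprecated_classes`` mapping in ``airflow/models/__init__.py`` above, so legacy imports such as ``from airflow.models.abstractoperator import DEFAULT_QUEUE`` keep working with a deprecation warning. The helper lives in ``airflow.utils.deprecation_tools``; the snippet below is only an illustrative sketch of the module-level ``__getattr__`` (PEP 562) pattern such shims rely on, not the actual Airflow implementation:

    import importlib
    import warnings

    # Maps a legacy attribute name to the module that now provides it
    # (illustrative subset of the mapping added in this patch).
    _MOVED = {
        "DEFAULT_QUEUE": "airflow.sdk.definitions._internal.abstractoperator",
    }

    def __getattr__(name: str):
        # Called only when normal module attribute lookup fails (PEP 562),
        # so attributes that still exist in this module are unaffected.
        if name in _MOVED:
            warnings.warn(
                f"Importing {name!r} from this module is deprecated; use {_MOVED[name]!r} instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return getattr(importlib.import_module(_MOVED[name]), name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

This is also why the edge3 executor change above wraps its ``DEFAULT_QUEUE`` import in ``try/except (ImportError, AttributeError)`` and falls back to the Task SDK path directly.
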
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index cb24a7cc6bdb7..8c84ec6bbeeb0 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -41,6 +41,7 @@
     DEFAULT_WEIGHT_RULE,
     AbstractOperator,
     NotMapped,
+    TaskStateChangeCallback,
 )
 from airflow.sdk.definitions._internal.expandinput import (
     DictOfListsExpandInput,
@@ -50,7 +51,7 @@
 from airflow.sdk.definitions._internal.types import NOTSET
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
-from airflow.typing_compat import Literal
+from airflow.typing_compat import Literal, TypeGuard
 from airflow.utils.helpers import is_container, prevent_duplicates
 from airflow.utils.xcom import XCOM_RETURN_KEY
@@ -60,9 +61,6 @@
     import jinja2  # Slow import.
     import pendulum

-    from airflow.models.abstractoperator import (
-        TaskStateChangeCallback,
-    )
     from airflow.models.expandinput import (
         OperatorExpandArgument,
         OperatorExpandKwargsArgument,
@@ -76,14 +74,12 @@
     from airflow.sdk.types import Operator
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
     from airflow.triggers.base import StartTriggerArgs
-    from airflow.typing_compat import TypeGuard
     from airflow.utils.context import Context
     from airflow.utils.operator_resources import Resources
     from airflow.utils.task_group import TaskGroup
     from airflow.utils.trigger_rule import TriggerRule

-    TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]]
-
+TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]]
 ValidationSource = Union[Literal["expand"], Literal["partial"]]

From ad876558a8468c4bdc404610b360bbf3eada8fed Mon Sep 17 00:00:00 2001
From: Elad Kalif <45845474+eladkal@users.noreply.github.com>
Date: Mon, 14 Jul 2025 10:52:03 +0300
Subject: [PATCH 03/11] Add note about ruff rules and preview flag (#53331)

(cherry picked from commit d4d4cce8d290fd5b4b51bdf217419c5b6ab889cc)
---
 .../installation/upgrading_to_airflow3.rst | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/airflow-core/docs/installation/upgrading_to_airflow3.rst b/airflow-core/docs/installation/upgrading_to_airflow3.rst
index 560f352ee4138..8b733ad27f061 100644
--- a/airflow-core/docs/installation/upgrading_to_airflow3.rst
+++ b/airflow-core/docs/installation/upgrading_to_airflow3.rst
@@ -108,9 +108,25 @@ Some changes can be automatically fixed. To do so, run the following command:

     ruff check dags/ --select AIR301 --fix --preview

+Some of the fixes are marked as unsafe. Unsafe fixes usually do not break dag code; they are marked as unsafe because they may change some runtime behavior. For more information, see `Fix Safety `_.
+To trigger these fixes, run the following command:
+
+.. code-block:: bash
+
+    ruff check dags/ --select AIR301 --fix --unsafe-fixes --preview
+
+.. note::
+    Ruff has a strict policy about when a rule becomes stable. Until a rule is stable, you must use the ``--preview`` flag.
+    The progress of the Airflow Ruff rules becoming stable can be tracked in https://github.com/astral-sh/ruff/issues/17749
+    That said, from the Airflow side the rules are perfectly fine to use.
+
+.. note::
+
+    In AIR rules, unsafe fixes involve changing import paths while keeping the name of the imported member the same.
+    For instance, changing the import from ``from airflow.sensors.base_sensor_operator import BaseSensorOperator`` to ``from airflow.sdk.bases.sensor import BaseSensorOperator`` requires ruff to remove the original import before adding the new one.
+    In contrast, safe fixes include changes to both the member name and the import path, such as changing ``from airflow.datasets import Dataset`` to ``from airflow.sdk import Asset``. These adjustments do not require ruff to remove the old import.
+    To remove unused legacy imports, it is necessary to enable the ``unused-import`` rule (F401).
+
 You can also configure these flags through configuration files. See `Configuring Ruff `_ for details.

 Step 4: Install the Standard Providers
 --------------------------------------

 - Some of the commonly used Operators which were bundled as part of the ``airflow-core`` package
   (for example ``BashOperator`` and ``PythonOperator``)

From 9fd6c3e3f782713f3ae1b0155b4377ac9e520c67 Mon Sep 17 00:00:00 2001
From: r-richmond
Date: Thu, 17 Jul 2025 10:27:53 -0700
Subject: [PATCH 04/11] Fix broken link in advanced logging config docs (#53460)

(cherry picked from commit df5c949db5209e3e0b27af537b2ca53c7aefde64)
---
 .../logging-monitoring/advanced-logging-configuration.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/airflow-core/docs/administration-and-deployment/logging-monitoring/advanced-logging-configuration.rst b/airflow-core/docs/administration-and-deployment/logging-monitoring/advanced-logging-configuration.rst
index 342de5e016a3b..fa34e74f3140b 100644
--- a/airflow-core/docs/administration-and-deployment/logging-monitoring/advanced-logging-configuration.rst
+++ b/airflow-core/docs/administration-and-deployment/logging-monitoring/advanced-logging-configuration.rst
@@ -34,7 +34,7 @@ Some configuration options require that the logging config class be overwritten.
 configuration of Airflow and modifying it to suit your needs.

 The default configuration can be seen in the
-`airflow_local_settings.py template `_
+`airflow_local_settings.py template `_
 and you can see the loggers and handlers used there.

 See :ref:`Configuring local settings ` for details on how to

From 9e26810a1fc978f2d18a524fe1b910cd3056a8c7 Mon Sep 17 00:00:00 2001
From: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com>
Date: Thu, 17 Jul 2025 14:53:42 -0600
Subject: [PATCH 05/11] Update dag bundles docs; add s3, fix git classpath (#53473)

(cherry picked from commit bf5fd5ff69851db394e50a10e825a744487c0779)
---
 .../docs/administration-and-deployment/dag-bundles.rst | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/airflow-core/docs/administration-and-deployment/dag-bundles.rst b/airflow-core/docs/administration-and-deployment/dag-bundles.rst
index 057d354b1f523..48401b0c1439c 100644
--- a/airflow-core/docs/administration-and-deployment/dag-bundles.rst
+++ b/airflow-core/docs/administration-and-deployment/dag-bundles.rst
@@ -50,6 +50,9 @@ Airflow supports multiple types of dag Bundles, each catering to specific use cases:
 **airflow.providers.git.bundles.git.GitDagBundle**
     These bundles integrate with Git repositories, allowing Airflow to fetch dags directly from a repository.

+**airflow.providers.amazon.aws.bundles.s3.S3DagBundle**
+    These bundles reference an S3 bucket containing DAG files. They do not support versioning of the bundle, meaning tasks always run using the latest code.
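+
+For example, an S3 bundle entry in ``dag_bundle_config_list`` (see the configuration section below) could look like the following sketch; the ``kwargs`` here are illustrative, so check the Amazon provider documentation for the exact supported arguments:
+
+.. code-block:: json
+
+    {
+        "name": "my_s3_dags",
+        "classpath": "airflow.providers.amazon.aws.bundles.s3.S3DagBundle",
+        "kwargs": {"bucket_name": "my-dag-bucket", "prefix": "dags/", "aws_conn_id": "aws_default"}
+    }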
+ Configuring dag bundles ----------------------- @@ -65,7 +68,7 @@ For example, adding multiple dag bundles to your ``airflow.cfg`` file: dag_bundle_config_list = [ { "name": "my_git_repo", - "classpath": "airflow.dag_processing.bundles.git.GitDagBundle", + "classpath": "airflow.providers.git.bundles.git.GitDagBundle", "kwargs": {"tracking_ref": "main", "git_conn_id": "my_git_conn"} }, { From 63387be0043b5ad69f1f942efb192e741bf85943 Mon Sep 17 00:00:00 2001 From: Evgenii Prusov <114025336+evgenii-prusov@users.noreply.github.com> Date: Tue, 22 Jul 2025 02:57:25 +0200 Subject: [PATCH 06/11] Fixed Task group names duplication in Task's task_id for MappedOperator (#53532) Co-authored-by: Wei Lee Co-authored-by: Evgenii Prusov (cherry picked from commit 6b618efa91b4ef80c4821537b30f14a49c2badb6) --- task-sdk/src/airflow/sdk/bases/operator.py | 8 +++++-- .../definitions/test_mappedoperator.py | 22 +++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 5de513bda9e02..a3d6f58edfd08 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -1014,9 +1014,13 @@ def __init__( ): # Note: Metaclass handles passing in the DAG/TaskGroup from active context manager, if any - self.task_id = task_group.child_id(task_id) if task_group else task_id - if not self.__from_mapped and task_group: + # Only apply task_group prefix if this operator was not created from a mapped operator + # Mapped operators already have the prefix applied during their creation + if task_group and not self.__from_mapped: + self.task_id = task_group.child_id(task_id) task_group.add(self) + else: + self.task_id = task_id super().__init__() self.task_group = task_group diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py index f9c075352a0d4..7a4b43c81040a 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py +++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py @@ -691,3 +691,25 @@ def group(x): ), ] ) + + +def test_mapped_operator_in_task_group_no_duplicate_prefix(): + """Test that task_id doesn't get duplicated prefix when unmapping a mapped operator in a task group.""" + from airflow.sdk.definitions.taskgroup import TaskGroup + + with DAG("test-dag"): + with TaskGroup(group_id="tg1") as tg1: + # Create a mapped task within the task group + mapped_task = MockOperator.partial(task_id="mapped_task", arg1="a").expand(arg2=["a", "b", "c"]) + + # Check the mapped operator has correct task_id + assert mapped_task.task_id == "tg1.mapped_task" + assert mapped_task.task_group == tg1 + assert mapped_task.task_group.group_id == "tg1" + + # Simulate what happens during execution - unmap the operator + # unmap expects resolved kwargs + unmapped = mapped_task.unmap({"arg2": "a"}) + + # The unmapped operator should have the same task_id, not a duplicate prefix + assert unmapped.task_id == "tg1.mapped_task", f"Expected 'tg1.mapped_task' but got '{unmapped.task_id}'" From de52b4350e237a3fad9be0940d0a819eb0e16fd5 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 29 Jul 2025 15:04:37 +0530 Subject: [PATCH 07/11] Fix custom xcom backend serialize when BaseXCom.get_all is used (#53814) (cherry picked from commit a8c4ba35351afeef8f15c9627e04241a29054fe4) --- task-sdk/src/airflow/sdk/bases/xcom.py | 11 ++- .../execution_time/test_task_runner.py | 75 +++++++++++++++++++ 2 files 
changed, 82 insertions(+), 4 deletions(-) diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 82df8d151ab13..423fde1201237 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -17,6 +17,7 @@ from __future__ import annotations +import collections from typing import Any, Protocol import structlog @@ -30,6 +31,9 @@ XComSequenceSliceResult, ) +# Lightweight wrapper for XCom values +_XComValueWrapper = collections.namedtuple("_XComValueWrapper", "value") + log = structlog.get_logger(logger_name="task") @@ -290,7 +294,6 @@ def get_all( :return: List of all XCom values if found. """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - from airflow.serialization.serde import deserialize msg = SUPERVISOR_COMMS.send( msg=GetXComSequenceSlice( @@ -307,10 +310,10 @@ def get_all( if not isinstance(msg, XComSequenceSliceResult): raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}") - result = deserialize(msg.root) - if not result: + if not msg.root: return None - return result + + return [cls.deserialize_value(_XComValueWrapper(value)) for value in msg.root] @staticmethod def serialize_value( diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 362b89e3ad0a5..4b445404d8a29 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2014,6 +2014,81 @@ def execute(self, context): for x in mock_supervisor_comms.send.call_args_list ) + def test_get_all_uses_custom_deserialize_value(self, mock_supervisor_comms): + """ + Tests that XCom.get_all() calls the custom deserialize_value method. 
+ """ + + class CustomXCom(BaseXCom): + @classmethod + def deserialize_value(cls, result): + """Custom deserialization that adds a prefix to show it was called.""" + original_value = super().deserialize_value(result) + return f"from custom xcom deserialize:{original_value}" + + serialized_values = ["value1", "value2", "value3"] + mock_supervisor_comms.send.return_value = XComSequenceSliceResult(root=serialized_values) + + result = CustomXCom.get_all(key="test_key", dag_id="test_dag", task_id="test_task", run_id="test_run") + + expected = [ + "from custom xcom deserialize:value1", + "from custom xcom deserialize:value2", + "from custom xcom deserialize:value3", + ] + assert result == expected + + @pytest.mark.parametrize( + ("include_prior_dates", "expected_value"), + [ + pytest.param(True, True, id="include_prior_dates_true"), + pytest.param(False, False, id="include_prior_dates_false"), + pytest.param(None, False, id="include_prior_dates_default"), + ], + ) + def test_xcom_pull_with_include_prior_dates( + self, + create_runtime_ti, + mock_supervisor_comms, + include_prior_dates, + expected_value, + ): + """Test that xcom_pull with include_prior_dates parameter correctly behaves as we expect.""" + task = BaseOperator(task_id="pull_task") + runtime_ti = create_runtime_ti(task=task) + + value = {"previous_run_data": "test_value"} + ser_value = BaseXCom.serialize_value(value) + + def mock_send_side_effect(*args, **kwargs): + msg = kwargs.get("msg") or args[0] + if isinstance(msg, GetXComSequenceSlice): + assert msg.include_prior_dates is expected_value, ( + f"include_prior_dates should be {expected_value} in GetXComSequenceSlice" + ) + return XComSequenceSliceResult(root=[ser_value]) + return XComResult(key="test_key", value=None) + + mock_supervisor_comms.send.side_effect = mock_send_side_effect + kwargs = {"key": "test_key", "task_ids": "previous_task"} + if include_prior_dates is not None: + kwargs["include_prior_dates"] = include_prior_dates + result = runtime_ti.xcom_pull(**kwargs) + assert result == value + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXComSequenceSlice( + key="test_key", + dag_id=runtime_ti.dag_id, + run_id=runtime_ti.run_id, + task_id="previous_task", + start=None, + stop=None, + step=None, + include_prior_dates=expected_value, + ), + ) + class TestDagParamRuntime: DEFAULT_ARGS = { From 5935b73014b2f261415d985f821a4a228f584b14 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Mon, 11 Aug 2025 13:17:55 +0530 Subject: [PATCH 08/11] Restore ``get_previous_dagrun`` functionality for task context (#53655) Co-authored-by: Kaxil Naik (cherry picked from commit 35d23c3222472ef354acc34fd3a76237bced727b) --- .../execution_api/datamodels/taskinstance.py | 8 +- .../execution_api/routes/dag_runs.py | 45 +++++- .../execution_api/versions/__init__.py | 4 + .../execution_api/versions/v2025_08_10.py | 39 +++++ .../versions/head/test_dag_runs.py | 139 ++++++++++++++++++ .../versions/head/test_task_instances.py | 4 +- .../unit/dag_processing/test_processor.py | 2 + .../src/tests_common/pytest_plugin.py | 27 ++-- .../unit/openlineage/plugins/test_listener.py | 29 ++-- task-sdk/src/airflow/sdk/api/client.py | 18 +++ .../airflow/sdk/api/datamodels/_generated.py | 3 +- .../src/airflow/sdk/execution_time/comms.py | 17 +++ .../airflow/sdk/execution_time/supervisor.py | 7 + .../airflow/sdk/execution_time/task_runner.py | 27 ++++ task-sdk/src/airflow/sdk/types.py | 2 + task-sdk/tests/conftest.py | 2 + task-sdk/tests/task_sdk/api/test_client.py | 82 +++++++++++ 
.../task_sdk/execution_time/test_comms.py | 1 + .../execution_time/test_supervisor.py | 70 +++++++++ .../execution_time/test_task_runner.py | 63 ++++++++ 20 files changed, 555 insertions(+), 34 deletions(-) create mode 100644 airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 2d7968bbc625e..f1e42ef1e0b72 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -35,7 +35,12 @@ from airflow.api_fastapi.execution_api.datamodels.asset import AssetProfile from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse -from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState +from airflow.utils.state import ( + DagRunState, + IntermediateTIState, + TaskInstanceState as TIState, + TerminalTIState, +) from airflow.utils.types import DagRunType AwareDatetimeAdapter = TypeAdapter(AwareDatetime) @@ -292,6 +297,7 @@ class DagRun(StrictBaseModel): end_date: UtcDateTime | None clear_number: int = 0 run_type: DagRunType + state: DagRunState conf: Annotated[dict[str, Any], Field(default_factory=dict)] consumed_asset_events: list[AssetEventDagRunReference] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py index db27f7bd93ed8..d0dcef4faebe2 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py @@ -27,10 +27,12 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.types import UtcDateTime from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload +from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun from airflow.exceptions import DagRunAlreadyExists from airflow.models.dag import DagModel from airflow.models.dagbag import DagBag -from airflow.models.dagrun import DagRun +from airflow.models.dagrun import DagRun as DagRunModel +from airflow.utils.state import DagRunState from airflow.utils.types import DagRunTriggeredByType router = APIRouter() @@ -140,7 +142,9 @@ def get_dagrun_state( session: SessionDep, ) -> DagRunStateResponse: """Get a DAG Run State.""" - dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)) + dag_run = session.scalar( + select(DagRunModel).where(DagRunModel.dag_id == dag_id, DagRunModel.run_id == run_id) + ) if dag_run is None: raise HTTPException( status.HTTP_404_NOT_FOUND, @@ -162,16 +166,45 @@ def get_dr_count( states: Annotated[list[str] | None, Query()] = None, ) -> int: """Get the count of DAG runs matching the given criteria.""" - query = select(func.count()).select_from(DagRun).where(DagRun.dag_id == dag_id) + query = select(func.count()).select_from(DagRunModel).where(DagRunModel.dag_id == dag_id) if logical_dates: - query = query.where(DagRun.logical_date.in_(logical_dates)) + query = query.where(DagRunModel.logical_date.in_(logical_dates)) if run_ids: - query = query.where(DagRun.run_id.in_(run_ids)) + query = query.where(DagRunModel.run_id.in_(run_ids)) 
if states: - query = query.where(DagRun.state.in_(states)) + query = query.where(DagRunModel.state.in_(states)) count = session.scalar(query) return count or 0 + + +@router.get("/{dag_id}/previous", status_code=status.HTTP_200_OK) +def get_previous_dagrun( + dag_id: str, + logical_date: UtcDateTime, + session: SessionDep, + state: Annotated[DagRunState | None, Query()] = None, +) -> DagRun | None: + """Get the previous DAG run before the given logical date, optionally filtered by state.""" + query = ( + select(DagRunModel) + .where( + DagRunModel.dag_id == dag_id, + DagRunModel.logical_date < logical_date, + ) + .order_by(DagRunModel.logical_date.desc()) + .limit(1) + ) + + if state: + query = query.where(DagRunModel.state == state) + + dag_run = session.scalar(query) + + if not dag_run: + return None + + return DagRun.model_validate(dag_run) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 5462f10297495..fee781a8ecf35 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -21,9 +21,13 @@ from airflow.api_fastapi.execution_api.versions.v2025_04_28 import AddRenderedMapIndexField from airflow.api_fastapi.execution_api.versions.v2025_05_20 import DowngradeUpstreamMapIndexes +from airflow.api_fastapi.execution_api.versions.v2025_08_10 import ( + AddDagRunStateFieldAndPreviousEndpoint, +) bundle = VersionBundle( HeadVersion(), + Version("2025-08-10", AddDagRunStateFieldAndPreviousEndpoint), Version("2025-05-20", DowngradeUpstreamMapIndexes), Version("2025-04-28", AddRenderedMapIndexField), Version("2025-04-11"), diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py new file mode 100644 index 0000000000000..188eaec2d79ce --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
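+#
+# How this version change works: the head (latest) schemas include the new
+# ``DagRun.state`` field and the ``/dag-runs/{dag_id}/previous`` endpoint;
+# Cadwyn applies ``instructions_to_migrate_to_previous_version`` to derive
+# the older API versions without them, and the response converter strips
+# ``state`` from ``dag_run`` payloads served to clients pinned to a version
+# before 2025-08-10.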
+ +from __future__ import annotations + +from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, endpoint, schema + +from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun, TIRunContext + + +class AddDagRunStateFieldAndPreviousEndpoint(VersionChange): + """Add the `state` field to DagRun model and `/dag-runs/{dag_id}/previous` endpoint.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + schema(DagRun).field("state").didnt_exist, + endpoint("/dag-runs/{dag_id}/previous", ["GET"]).didnt_exist, + ) + + @convert_response_to_previous_version_for(TIRunContext) # type: ignore[arg-type] + def remove_state_from_dag_run(response: ResponseInfo) -> None: # type: ignore[misc] + """Remove the `state` field from the dag_run object when converting to the previous version.""" + if "dag_run" in response.body and isinstance(response.body["dag_run"], dict): + response.body["dag_run"].pop("state", None) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py index f9f8d489d3d26..ac414f53ee83a 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py @@ -314,3 +314,142 @@ def test_get_count_with_mixed_states(self, client, session, dag_maker): ) assert response.status_code == 200 assert response.json() == 2 + + +class TestGetPreviousDagRun: + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + def test_get_previous_dag_run_basic(self, client, session, dag_maker): + """Test getting the previous DAG run without state filtering.""" + dag_id = "test_get_previous_basic" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + # Create multiple DAG runs + dag_maker.create_dagrun( + run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS + ) + dag_maker.create_dagrun( + run_id="run2", logical_date=timezone.datetime(2025, 1, 5), state=DagRunState.FAILED + ) + dag_maker.create_dagrun( + run_id="run3", logical_date=timezone.datetime(2025, 1, 10), state=DagRunState.SUCCESS + ) + session.commit() + + # Query for previous DAG run before 2025-01-10 + response = client.get( + f"/execution/dag-runs/{dag_id}/previous", + params={ + "logical_date": timezone.datetime(2025, 1, 10).isoformat(), + }, + ) + + assert response.status_code == 200 + result = response.json() + assert result["dag_id"] == dag_id + assert result["run_id"] == "run2" # Most recent before 2025-01-10 + assert result["state"] == "failed" + + def test_get_previous_dag_run_with_state_filter(self, client, session, dag_maker): + """Test getting the previous DAG run with state filtering.""" + dag_id = "test_get_previous_with_state" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + # Create multiple DAG runs with different states + dag_maker.create_dagrun( + run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS + ) + dag_maker.create_dagrun( + run_id="run2", logical_date=timezone.datetime(2025, 1, 5), state=DagRunState.FAILED + ) + dag_maker.create_dagrun( + run_id="run3", logical_date=timezone.datetime(2025, 1, 8), state=DagRunState.SUCCESS + ) + session.commit() + + # Query for previous successful DAG run before 2025-01-10 + 
        response = client.get(
+            f"/execution/dag-runs/{dag_id}/previous",
+            params={"logical_date": timezone.datetime(2025, 1, 10).isoformat(), "state": "success"},
+        )
+
+        assert response.status_code == 200
+        result = response.json()
+        assert result["dag_id"] == dag_id
+        assert result["run_id"] == "run3"  # Most recent successful run before 2025-01-10
+        assert result["state"] == "success"
+
+    def test_get_previous_dag_run_no_previous_found(self, client, session, dag_maker):
+        """Test getting previous DAG run when none exists returns null."""
+        dag_id = "test_get_previous_none"
+
+        with dag_maker(dag_id=dag_id, session=session, serialized=True):
+            EmptyOperator(task_id="test_task")
+
+        # Create only one DAG run - no previous should exist
+        dag_maker.create_dagrun(
+            run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS
+        )
+
+        response = client.get(f"/execution/dag-runs/{dag_id}/previous?logical_date=2025-01-01T00:00:00Z")
+
+        assert response.status_code == 200
+        assert response.json() is None  # Should return null
+
+    def test_get_previous_dag_run_no_matching_state(self, client, session, dag_maker):
+        """Test getting previous DAG run with state filter that matches nothing returns null."""
+        dag_id = "test_get_previous_no_match"
+
+        with dag_maker(dag_id=dag_id, session=session, serialized=True):
+            EmptyOperator(task_id="test_task")
+
+        # Create DAG runs with different states
+        dag_maker.create_dagrun(
+            run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.FAILED
+        )
+        dag_maker.create_dagrun(
+            run_id="run2", logical_date=timezone.datetime(2025, 1, 2), state=DagRunState.FAILED
+        )
+
+        # Look for previous success but only failed runs exist
+        response = client.get(
+            f"/execution/dag-runs/{dag_id}/previous?logical_date=2025-01-03T00:00:00Z&state=success"
+        )
+
+        assert response.status_code == 200
+        assert response.json() is None
+
+    def test_get_previous_dag_run_dag_not_found(self, client, session):
+        """Test getting previous DAG run for a non-existent DAG returns null."""
+        response = client.get(
+            "/execution/dag-runs/nonexistent_dag/previous?logical_date=2025-01-01T00:00:00Z"
+        )
+
+        assert response.status_code == 200
+        assert response.json() is None
+
+    def test_get_previous_dag_run_invalid_state_parameter(self, client, session, dag_maker):
+        """Test that invalid state parameter returns 422 validation error."""
+        dag_id = "test_get_previous_invalid_state"
+
+        with dag_maker(dag_id=dag_id, session=session, serialized=True):
+            EmptyOperator(task_id="test_task")
+
+        dag_maker.create_dagrun(
+            run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS
+        )
+        session.commit()
+
+        response = client.get(
+            f"/execution/dag-runs/{dag_id}/previous?logical_date=2025-01-02T00:00:00Z&state=invalid_state"
+        )
+
+        assert response.status_code == 422

diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 3c107f61863f2..4f1e28a11f628 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -35,7 +35,7 @@
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.sdk import Asset, TaskGroup, task, task_group
 from airflow.utils import timezone
-from airflow.utils.state import State, TaskInstanceState, TerminalTIState
+from airflow.utils.state
import DagRunState, State, TaskInstanceState, TerminalTIState from tests_common.test_utils.db import ( clear_db_assets, @@ -155,6 +155,7 @@ def test_ti_run_state_to_running( ti = create_task_instance( task_id="test_ti_run_state_to_running", state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, session=session, start_date=instant, dag_id=str(uuid4()), @@ -184,6 +185,7 @@ def test_ti_run_state_to_running( "data_interval_end": instant_str, "run_after": instant_str, "start_date": instant_str, + "state": "running", "end_date": None, "run_type": "manual", "conf": {}, diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 1a3925c077db3..dc85bfb1b21fa 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -57,6 +57,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.sdk import DAG from airflow.sdk.api.client import Client +from airflow.sdk.api.datamodels._generated import DagRunState from airflow.sdk.execution_time import comms from airflow.utils import timezone from airflow.utils.session import create_session @@ -957,6 +958,7 @@ def fake_collect_dags(self, *args, **kwargs): logical_date=timezone.utcnow(), start_date=timezone.utcnow(), run_type="manual", + state=DagRunState.RUNNING, ) dag_run.run_after = timezone.utcnow() diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 2288fefc26880..e161771b84aad 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2134,7 +2134,7 @@ def _create_task_instance( should_retry: bool | None = None, max_tries: int | None = None, ) -> RuntimeTaskInstance: - from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext + from airflow.sdk.api.datamodels._generated import DagRun, DagRunState, TIRunContext from airflow.utils.types import DagRunType if not ti_id: @@ -2167,17 +2167,20 @@ def _create_task_instance( run_after = data_interval_end or logical_date or timezone.utcnow() ti_context = TIRunContext( - dag_run=DagRun( - dag_id=dag_id, - run_id=run_id, - logical_date=logical_date, # type: ignore - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, - start_date=start_date, # type: ignore - run_type=run_type, # type: ignore - run_after=run_after, # type: ignore - conf=conf, - consumed_asset_events=[], + dag_run=DagRun.model_validate( + { + "dag_id": dag_id, + "run_id": run_id, + "logical_date": logical_date, # type: ignore + "data_interval_start": data_interval_start, + "data_interval_end": data_interval_end, + "start_date": start_date, # type: ignore + "run_type": run_type, # type: ignore + "run_after": run_after, # type: ignore + "conf": conf, + "consumed_asset_events": [], + **({"state": DagRunState.RUNNING} if "state" in DagRun.model_fields else {}), + } ), task_reschedule_count=task_reschedule_count, max_tries=task_retries if max_tries is None else max_tries, diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py index eff8dc9efcb88..555725f0246a3 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py @@ -853,6 +853,7 @@ def _create_listener_and_task_instance( task_instance.dag_run.clear_number = 0 
task_instance.dag_run.logical_date = timezone.datetime(2020, 1, 1, 1, 1, 1) task_instance.dag_run.run_after = timezone.datetime(2020, 1, 1, 1, 1, 1) + task_instance.dag_run.state = DagRunState.RUNNING task_instance.task = None task_instance.dag = None task_instance.task_id = "task_id" @@ -862,6 +863,7 @@ def _create_listener_and_task_instance( # RuntimeTaskInstance is used when on worker from airflow.sdk.api.datamodels._generated import ( DagRun as SdkDagRun, + DagRunState as SdkDagRunState, DagRunType, TaskInstance as SdkTaskInstance, TIRunContext, @@ -887,19 +889,20 @@ def _create_listener_and_task_instance( **sdk_task_instance.model_dump(exclude_unset=True), task=task, _ti_context_from_server=TIRunContext( - dag_run=SdkDagRun( - dag_id="dag_id", - run_id="dag_run_run_id", - logical_date=timezone.datetime(2020, 1, 1, 1, 1, 1), - data_interval_start=None, - data_interval_end=None, - start_date=timezone.datetime(2023, 1, 1, 13, 1, 1), - end_date=timezone.datetime(2023, 1, 3, 13, 1, 1), - clear_number=0, - run_type=DagRunType.MANUAL, - run_after=timezone.datetime(2023, 1, 3, 13, 1, 1), - conf=None, - consumed_asset_events=[], + dag_run=SdkDagRun.model_validate( + { + "dag_id": "dag_id_from_dagrun_not_ti", + "run_id": "dag_run_run_id_from_dagrun_not_ti", + "logical_date": timezone.datetime(2020, 1, 1, 1, 1, 1), + "start_date": timezone.datetime(2023, 1, 1, 13, 1, 1), + "end_date": timezone.datetime(2023, 1, 3, 13, 1, 1), + "run_type": DagRunType.MANUAL, + "run_after": timezone.datetime(2023, 1, 3, 13, 1, 1), + "consumed_asset_events": [], + **( + {"state": SdkDagRunState.RUNNING} if "state" in SdkDagRun.model_fields else {} + ), + } ), task_reschedule_count=0, max_tries=1, diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 3413e98189cfc..f36c46f45dea2 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -69,6 +69,7 @@ DRCount, ErrorResponse, OKResponse, + PreviousDagRunResult, SkipDownstreamTasks, TaskRescheduleStartDate, TICount, @@ -620,6 +621,23 @@ def get_count( resp = self.client.get("dag-runs/count", params=params) return DRCount(count=resp.json()) + def get_previous( + self, + dag_id: str, + logical_date: datetime, + state: str | None = None, + ) -> PreviousDagRunResult: + """Get the previous DAG run before the given logical date, optionally filtered by state.""" + params = { + "logical_date": logical_date.isoformat(), + } + + if state: + params["state"] = state + + resp = self.client.get(f"dag-runs/{dag_id}/previous", params=params) + return PreviousDagRunResult(dag_run=resp.json()) + class BearerAuth(httpx.Auth): def __init__(self, token: str): diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 1dabd8c90228d..9cdc08379ac75 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -27,7 +27,7 @@ from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel -API_VERSION: Final[str] = "2025-05-20" +API_VERSION: Final[str] = "2025-08-10" class AssetAliasReferenceAssetEventDagRun(BaseModel): @@ -494,6 +494,7 @@ class DagRun(BaseModel): end_date: Annotated[AwareDatetime | None, Field(title="End Date")] = None clear_number: Annotated[int | None, Field(title="Clear Number")] = 0 run_type: DagRunType + state: DagRunState conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None consumed_asset_events: 
Annotated[list[AssetEventDagRunReference], Field(title="Consumed Asset Events")] diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 2c6dfea4e601c..1f9eb3141a319 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -70,6 +70,7 @@ AssetResponse, BundleInfo, ConnectionResponse, + DagRun, DagRunStateResponse, InactiveAssetsResponse, PrevSuccessfulDagRunResponse, @@ -492,6 +493,13 @@ def from_api_response(cls, dr_state_response: DagRunStateResponse) -> DagRunStat return cls(**dr_state_response.model_dump(exclude_defaults=True), type="DagRunStateResult") +class PreviousDagRunResult(BaseModel): + """Response containing previous DAG run information.""" + + dag_run: DagRun | None = None + type: Literal["PreviousDagRunResult"] = "PreviousDagRunResult" + + class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse): type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult" @@ -579,6 +587,7 @@ class SentFDs(BaseModel): XComSequenceSliceResult, InactiveAssetsResult, OKResponse, + PreviousDagRunResult, ], Field(discriminator="type"), ] @@ -775,6 +784,13 @@ class GetDagRunState(BaseModel): type: Literal["GetDagRunState"] = "GetDagRunState" +class GetPreviousDagRun(BaseModel): + dag_id: str + logical_date: AwareDatetime + state: str | None = None + type: Literal["GetPreviousDagRun"] = "GetPreviousDagRun" + + class GetAssetByName(BaseModel): name: str type: Literal["GetAssetByName"] = "GetAssetByName" @@ -853,6 +869,7 @@ class GetDRCount(BaseModel): GetDagRunState, GetDRCount, GetPrevSuccessfulDagRun, + GetPreviousDagRun, GetTaskRescheduleStartDate, GetTICount, GetTaskStates, diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 57224f66a5882..ccac2d8d2ad1c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -83,6 +83,7 @@ GetConnection, GetDagRunState, GetDRCount, + GetPreviousDagRun, GetPrevSuccessfulDagRun, GetTaskRescheduleStartDate, GetTaskStates, @@ -1227,6 +1228,12 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: run_ids=msg.run_ids, states=msg.states, ) + elif isinstance(msg, GetPreviousDagRun): + resp = self.client.dag_runs.get_previous( + dag_id=msg.dag_id, + logical_date=msg.logical_date, + state=msg.state, + ) elif isinstance(msg, DeleteVariable): resp = self.client.variables.delete(msg.key) elif isinstance(msg, ValidateInletsAndOutlets): diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 62aa7d37b7f7a..a6bf2d11e1b13 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -44,6 +44,7 @@ from airflow.listeners.listener import get_listener_manager from airflow.sdk.api.datamodels._generated import ( AssetProfile, + DagRun, TaskInstance, TaskInstanceState, TIRunContext, @@ -65,10 +66,12 @@ ErrorResponse, GetDagRunState, GetDRCount, + GetPreviousDagRun, GetTaskRescheduleStartDate, GetTaskStates, GetTICount, InactiveAssetsResult, + PreviousDagRunResult, RescheduleTask, ResendLoggingFD, RetryTask, @@ -438,6 +441,30 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: return response.start_date + def get_previous_dagrun(self, state: str | None = None) -> DagRun | None: + """Return the 
previous DAG run before the given logical date, optionally filtered by state.""" + context = self.get_template_context() + dag_run = context.get("dag_run") + + log = structlog.get_logger(logger_name="task") + + log.debug("Getting previous DAG run", dag_run=dag_run) + + if dag_run is None: + return None + + if dag_run.logical_date is None: + return None + + response = SUPERVISOR_COMMS.send( + msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state) + ) + + if TYPE_CHECKING: + assert isinstance(response, PreviousDagRunResult) + + return response.dag_run + @staticmethod def get_ti_count( dag_id: str, diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 8bd0ea0db8d4d..abe1f1f84a822 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -85,6 +85,8 @@ def get_template_context(self) -> Context: ... def get_first_reschedule_date(self, first_try_number) -> AwareDatetime | None: ... + def get_previous_dagrun(self, state: str | None = None) -> DagRunProtocol | None: ... + @staticmethod def get_ti_count( dag_id: str, diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py index d1866f523919c..c3195a7b9bb8b 100644 --- a/task-sdk/tests/conftest.py +++ b/task-sdk/tests/conftest.py @@ -198,6 +198,7 @@ def __call__( def make_ti_context() -> MakeTIContextCallable: """Factory for creating TIRunContext objects.""" from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext + from airflow.utils.state import DagRunState def _make_context( dag_id: str = "test_dag", @@ -226,6 +227,7 @@ def _make_context( start_date=start_date, # type: ignore run_type=run_type, # type: ignore run_after=run_after, # type: ignore + state=DagRunState.RUNNING, conf=conf, # type: ignore consumed_asset_events=list(consumed_asset_events), ), diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index caa515de09a4d..4ffa847983546 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -19,6 +19,7 @@ import json import pickle +from datetime import datetime from unittest import mock import httpx @@ -41,6 +42,7 @@ DeferTask, ErrorResponse, OKResponse, + PreviousDagRunResult, RescheduleTask, TaskRescheduleStartDate, ) @@ -1139,6 +1141,86 @@ def handle_request(request: httpx.Request) -> httpx.Response: result = client.dag_runs.get_count(dag_id="test_dag", run_ids=["run1", "run2"]) assert result.count == 2 + def test_get_previous_basic(self): + """Test basic get_previous functionality with dag_id and logical_date.""" + logical_date = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dag-runs/test_dag/previous": + assert request.url.params["logical_date"] == logical_date.isoformat() + # Return complete DagRun data + return httpx.Response( + status_code=200, + json={ + "dag_id": "test_dag", + "run_id": "prev_run", + "logical_date": "2024-01-14T12:00:00+00:00", + "start_date": "2024-01-14T12:05:00+00:00", + "run_after": "2024-01-14T12:00:00+00:00", + "run_type": "scheduled", + "state": "success", + "consumed_asset_events": [], + }, + ) + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dag_runs.get_previous(dag_id="test_dag", logical_date=logical_date) + + assert isinstance(result, PreviousDagRunResult) + assert result.dag_run.dag_id == "test_dag" + 
assert result.dag_run.run_id == "prev_run" + assert result.dag_run.state == "success" + + def test_get_previous_with_state_filter(self): + """Test get_previous functionality with state filtering.""" + logical_date = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dag-runs/test_dag/previous": + assert request.url.params["logical_date"] == logical_date.isoformat() + assert request.url.params["state"] == "success" + # Return complete DagRun data + return httpx.Response( + status_code=200, + json={ + "dag_id": "test_dag", + "run_id": "prev_success_run", + "logical_date": "2024-01-14T12:00:00+00:00", + "start_date": "2024-01-14T12:05:00+00:00", + "run_after": "2024-01-14T12:00:00+00:00", + "run_type": "scheduled", + "state": "success", + "consumed_asset_events": [], + }, + ) + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dag_runs.get_previous(dag_id="test_dag", logical_date=logical_date, state="success") + + assert isinstance(result, PreviousDagRunResult) + assert result.dag_run.dag_id == "test_dag" + assert result.dag_run.run_id == "prev_success_run" + assert result.dag_run.state == "success" + + def test_get_previous_not_found(self): + """Test get_previous when no previous DAG run exists returns None.""" + logical_date = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dag-runs/test_dag/previous": + assert request.url.params["logical_date"] == logical_date.isoformat() + # Return None (null) when no previous DAG run found + return httpx.Response(status_code=200, content="null") + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dag_runs.get_previous(dag_id="test_dag", logical_date=logical_date) + + assert isinstance(result, PreviousDagRunResult) + assert result.dag_run is None + class TestTaskRescheduleOperations: def test_get_start_date(self): diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index 48c7ad74a1501..d0be736b4cea9 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -56,6 +56,7 @@ def test_recv_StartupDetails(self): "run_after": "2024-12-01T01:00:00Z", "end_date": None, "run_type": "manual", + "state": "success", "conf": None, "consumed_asset_events": [], }, diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 600ca4ae54972..511cccc2a33a1 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -51,7 +51,9 @@ AssetEventResponse, AssetProfile, AssetResponse, + DagRun, DagRunState, + DagRunType, TaskInstance, TaskInstanceState, ) @@ -75,6 +77,7 @@ GetConnection, GetDagRunState, GetDRCount, + GetPreviousDagRun, GetPrevSuccessfulDagRun, GetTaskRescheduleStartDate, GetTaskStates, @@ -85,6 +88,7 @@ GetXComSequenceSlice, InactiveAssetsResult, OKResponse, + PreviousDagRunResult, PrevSuccessfulDagRunResult, PutVariable, RescheduleTask, @@ -1822,6 +1826,72 @@ def watched_subprocess(self, mocker): None, id="get_dr_count", ), + pytest.param( + GetPreviousDagRun( + dag_id="test_dag", + 
logical_date=timezone.parse("2024-01-15T12:00:00Z"), + ), + { + "dag_run": { + "dag_id": "test_dag", + "run_id": "prev_run", + "logical_date": timezone.parse("2024-01-14T12:00:00Z"), + "run_type": "scheduled", + "start_date": timezone.parse("2024-01-15T12:00:00Z"), + "run_after": timezone.parse("2024-01-15T12:00:00Z"), + "consumed_asset_events": [], + "state": "success", + "data_interval_start": None, + "data_interval_end": None, + "end_date": None, + "clear_number": 0, + "conf": None, + }, + "type": "PreviousDagRunResult", + }, + "dag_runs.get_previous", + (), + { + "dag_id": "test_dag", + "logical_date": timezone.parse("2024-01-15T12:00:00Z"), + "state": None, + }, + PreviousDagRunResult( + dag_run=DagRun( + dag_id="test_dag", + run_id="prev_run", + logical_date=timezone.parse("2024-01-14T12:00:00Z"), + run_type=DagRunType.SCHEDULED, + start_date=timezone.parse("2024-01-15T12:00:00Z"), + run_after=timezone.parse("2024-01-15T12:00:00Z"), + consumed_asset_events=[], + state=DagRunState.SUCCESS, + ) + ), + None, + id="get_previous_dagrun", + ), + pytest.param( + GetPreviousDagRun( + dag_id="test_dag", + logical_date=timezone.parse("2024-01-15T12:00:00Z"), + state="success", + ), + { + "dag_run": None, + "type": "PreviousDagRunResult", + }, + "dag_runs.get_previous", + (), + { + "dag_id": "test_dag", + "logical_date": timezone.parse("2024-01-15T12:00:00Z"), + "state": "success", + }, + PreviousDagRunResult(dag_run=None), + None, + id="get_previous_dagrun_with_state", + ), pytest.param( GetTaskStates(dag_id="test_dag", task_group_id="test_group"), { diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 4b445404d8a29..b65c093856eab 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -50,6 +50,7 @@ from airflow.sdk.api.datamodels._generated import ( AssetProfile, AssetResponse, + DagRun, DagRunState, TaskInstance, TaskInstanceState, @@ -71,12 +72,14 @@ GetConnection, GetDagRunState, GetDRCount, + GetPreviousDagRun, GetTaskStates, GetTICount, GetVariable, GetXCom, GetXComSequenceSlice, OKResponse, + PreviousDagRunResult, PrevSuccessfulDagRunResult, SetRenderedFields, SetXCom, @@ -1802,6 +1805,66 @@ def test_get_task_states(self, mock_supervisor_comms): ) assert states == {"run1": {"task1": "running"}} + def test_get_previous_dagrun_basic(self, create_runtime_ti, mock_supervisor_comms): + """Test that get_previous_dagrun sends the correct request without state filter.""" + + task = BaseOperator(task_id="hello") + dag_id = "test_dag" + runtime_ti = create_runtime_ti(task=task, dag_id=dag_id, logical_date=timezone.datetime(2025, 1, 2)) + + dag_run_data = DagRun( + dag_id=dag_id, + run_id="prev_run", + logical_date=timezone.datetime(2025, 1, 1), + start_date=timezone.datetime(2025, 1, 1), + run_after=timezone.datetime(2025, 1, 1), + run_type="scheduled", + state="success", + consumed_asset_events=[], + ) + + mock_supervisor_comms.send.return_value = PreviousDagRunResult(dag_run=dag_run_data) + + dr = runtime_ti.get_previous_dagrun() + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetPreviousDagRun(dag_id="test_dag", logical_date=timezone.datetime(2025, 1, 2), state=None), + ) + assert dr.dag_id == "test_dag" + assert dr.run_id == "prev_run" + assert dr.state == "success" + + def test_get_previous_dagrun_with_state(self, create_runtime_ti, mock_supervisor_comms): + """Test that get_previous_dagrun sends the correct 
request with state filter.""" + + task = BaseOperator(task_id="hello") + dag_id = "test_dag" + runtime_ti = create_runtime_ti(task=task, dag_id=dag_id, logical_date=timezone.datetime(2025, 1, 2)) + + dag_run_data = DagRun( + dag_id=dag_id, + run_id="prev_success_run", + logical_date=timezone.datetime(2025, 1, 1), + start_date=timezone.datetime(2025, 1, 1), + run_after=timezone.datetime(2025, 1, 1), + run_type="scheduled", + state="success", + consumed_asset_events=[], + ) + + mock_supervisor_comms.send.return_value = PreviousDagRunResult(dag_run=dag_run_data) + + dr = runtime_ti.get_previous_dagrun(state="success") + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetPreviousDagRun( + dag_id="test_dag", logical_date=timezone.datetime(2025, 1, 2), state="success" + ), + ) + assert dr.dag_id == "test_dag" + assert dr.run_id == "prev_success_run" + assert dr.state == "success" + class TestXComAfterTaskExecution: @pytest.mark.parametrize( From b7998a7df074c954fb85811f885fb7c14c2c2108 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 29 Jul 2025 10:34:50 +0530 Subject: [PATCH 09/11] Do not ignore `include_prior_dates` in xcom_pull when `map_indexes` is not specified (#53809) (cherry picked from commit 2a2d3e1b3f60a5f2281859d612d26e0fdb7f4f80) --- .../api_fastapi/execution_api/routes/xcoms.py | 2 + .../execution_api/versions/__init__.py | 7 ++- .../execution_api/versions/v2025_08_10.py | 11 +++++ .../execution_api/versions/head/test_xcoms.py | 49 +++++++++++++++++++ task-sdk/src/airflow/sdk/api/client.py | 3 ++ task-sdk/src/airflow/sdk/bases/xcom.py | 5 ++ .../src/airflow/sdk/execution_time/comms.py | 1 + .../airflow/sdk/execution_time/supervisor.py | 9 +++- .../airflow/sdk/execution_time/task_runner.py | 1 + .../execution_time/test_supervisor.py | 3 +- .../execution_time/test_task_runner.py | 4 +- 11 files changed, 90 insertions(+), 5 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index 9762c7b34985e..eae3539f59317 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -230,6 +230,7 @@ class GetXComSliceFilterParams(BaseModel): start: int | None = None stop: int | None = None step: int | None = None + include_prior_dates: bool = False @router.get( @@ -249,6 +250,7 @@ def get_mapped_xcom_by_slice( key=key, task_ids=task_id, dag_ids=dag_id, + include_prior_dates=params.include_prior_dates, session=session, ) query = query.order_by(None) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index fee781a8ecf35..aeadd1affc92b 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -23,11 +23,16 @@ from airflow.api_fastapi.execution_api.versions.v2025_05_20 import DowngradeUpstreamMapIndexes from airflow.api_fastapi.execution_api.versions.v2025_08_10 import ( AddDagRunStateFieldAndPreviousEndpoint, + AddIncludePriorDatesToGetXComSlice, ) bundle = VersionBundle( HeadVersion(), - Version("2025-08-10", AddDagRunStateFieldAndPreviousEndpoint), + Version( + "2025-08-10", + AddDagRunStateFieldAndPreviousEndpoint, + AddIncludePriorDatesToGetXComSlice, + ), Version("2025-05-20", DowngradeUpstreamMapIndexes), Version("2025-04-28", AddRenderedMapIndexField), 
Version("2025-04-11"), diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py index 188eaec2d79ce..ec66915e4d908 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py @@ -20,6 +20,7 @@ from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, endpoint, schema from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun, TIRunContext +from airflow.api_fastapi.execution_api.routes.xcoms import GetXComSliceFilterParams class AddDagRunStateFieldAndPreviousEndpoint(VersionChange): @@ -37,3 +38,13 @@ def remove_state_from_dag_run(response: ResponseInfo) -> None: # type: ignore[m """Remove the `state` field from the dag_run object when converting to the previous version.""" if "dag_run" in response.body and isinstance(response.body["dag_run"], dict): response.body["dag_run"].pop("state", None) + + +class AddIncludePriorDatesToGetXComSlice(VersionChange): + """Add the `include_prior_dates` field to GetXComSliceFilterParams.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + schema(GetXComSliceFilterParams).field("include_prior_dates").didnt_exist, + ) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py index 1b10e81cd2338..ea93c2f96e00c 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py @@ -31,6 +31,7 @@ from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator from airflow.serialization.serde import deserialize, serialize +from airflow.utils import timezone from airflow.utils.session import create_session pytestmark = pytest.mark.db_test @@ -273,6 +274,54 @@ def __init__(self, *, x, **kwargs): assert response.status_code == 200 assert response.json() == ["f", "o", "b"][key] + @pytest.mark.parametrize( + "include_prior_dates, expected_xcoms", + [[True, ["earlier_value", "later_value"]], [False, ["later_value"]]], + ) + def test_xcom_get_slice_accepts_include_prior_dates( + self, client, dag_maker, session, include_prior_dates, expected_xcoms + ): + """Test that the slice endpoint accepts include_prior_dates parameter and works correctly.""" + + with dag_maker(dag_id="dag"): + EmptyOperator(task_id="task") + + earlier_run = dag_maker.create_dagrun( + run_id="earlier_run", logical_date=timezone.parse("2024-01-01T00:00:00Z") + ) + later_run = dag_maker.create_dagrun( + run_id="later_run", logical_date=timezone.parse("2024-01-02T00:00:00Z") + ) + + earlier_ti = earlier_run.get_task_instance("task") + later_ti = later_run.get_task_instance("task") + + earlier_xcom = XComModel( + key="test_key", + value="earlier_value", + dag_run_id=earlier_ti.dag_run.id, + run_id=earlier_ti.run_id, + task_id=earlier_ti.task_id, + dag_id=earlier_ti.dag_id, + ) + later_xcom = XComModel( + key="test_key", + value="later_value", + dag_run_id=later_ti.dag_run.id, + run_id=later_ti.run_id, + task_id=later_ti.task_id, + dag_id=later_ti.dag_id, + ) + session.add_all([earlier_xcom, later_xcom]) + session.commit() + + response = client.get( + f"/execution/xcoms/dag/later_run/task/test_key/slice?include_prior_dates={include_prior_dates}" + 
) + assert response.status_code == 200 + + assert response.json() == expected_xcoms + class TestXComsSetEndpoint: @pytest.mark.parametrize( diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index f36c46f45dea2..bbf6eb4dea024 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -491,6 +491,7 @@ def get_sequence_slice( start: int | None, stop: int | None, step: int | None, + include_prior_dates: bool = False, ) -> XComSequenceSliceResponse: params = {} if start is not None: @@ -499,6 +500,8 @@ def get_sequence_slice( params["stop"] = stop if step is not None: params["step"] = step + if include_prior_dates: + params["include_prior_dates"] = include_prior_dates resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice", params=params) return XComSequenceSliceResponse.model_validate_json(resp.read()) diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 423fde1201237..7c982a050ddf4 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -277,6 +277,7 @@ def get_all( dag_id: str, task_id: str, run_id: str, + include_prior_dates: bool = False, ) -> Any: """ Retrieve all XCom values for a task, typically from all map indexes. @@ -291,6 +292,9 @@ def get_all( :param run_id: DAG run ID for the task. :param dag_id: DAG ID to pull XComs from. :param task_id: Task ID to pull XComs from. + :param include_prior_dates: If *False* (default), only XComs from the + specified DAG run are returned. If *True*, the latest matching XComs are + returned regardless of the run they belong to. :return: List of all XCom values if found. """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS @@ -304,6 +308,7 @@ def get_all( start=None, stop=None, step=None, + include_prior_dates=include_prior_dates, ), ) diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 1f9eb3141a319..3069490521783 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -692,6 +692,7 @@ class GetXComSequenceSlice(BaseModel): start: int | None stop: int | None step: int | None + include_prior_dates: bool = False type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice" diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index ccac2d8d2ad1c..fffa1ad15553d 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1132,7 +1132,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: resp = xcom elif isinstance(msg, GetXComSequenceSlice): xcoms = self.client.xcoms.get_sequence_slice( - msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.start, msg.stop, msg.step + msg.dag_id, + msg.run_id, + msg.task_id, + msg.key, + msg.start, + msg.stop, + msg.step, + msg.include_prior_dates, ) resp = XComSequenceSliceResult.from_response(xcoms) elif isinstance(msg, DeferTask): diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index a6bf2d11e1b13..d4d09c8e20c48 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -361,6 +361,7 @@ def xcom_pull( key=key, task_id=t_id, dag_id=dag_id, + 
include_prior_dates=include_prior_dates, ) if values is None: diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 511cccc2a33a1..196494c258ddf 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -1954,10 +1954,11 @@ def watched_subprocess(self, mocker): start=None, stop=None, step=None, + include_prior_dates=False, ), {"root": ["foo", "bar"], "type": "XComSequenceSliceResult"}, "xcoms.get_sequence_slice", - ("test_dag", "test_run", "test_task", "test_key", None, None, None), + ("test_dag", "test_run", "test_task", "test_key", None, None, None, False), {}, XComSequenceSliceResult(root=["foo", "bar"]), None, diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index b65c093856eab..2556845dac897 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2049,8 +2049,7 @@ def test_xcom_pull_from_custom_xcom_backend( class CustomOperator(BaseOperator): def execute(self, context): - value = context["ti"].xcom_pull(task_ids="pull_task", key="key") - print(f"Pulled XCom Value: {value}") + context["ti"].xcom_pull(task_ids="pull_task", key="key") task = CustomOperator(task_id="pull_task") runtime_ti = create_runtime_ti(task=task) @@ -2061,6 +2060,7 @@ def execute(self, context): dag_id="test_dag", task_id="pull_task", run_id="test_run", + include_prior_dates=False, ) assert not any( From 9fc671c8cc77eb690cb42ff910c3b6b8bddb40e6 Mon Sep 17 00:00:00 2001 From: Karen Braganza Date: Tue, 5 Aug 2025 02:24:48 -0400 Subject: [PATCH 10/11] Allow setting and deleting variables and xcoms from triggers (#53514) (cherry picked from commit 13fa232b9e04c3446f2e05a9f49e61d5c16af73b) --- .../src/airflow/jobs/triggerer_job_runner.py | 20 ++++ .../tests/unit/jobs/test_triggerer_job.py | 100 +++++++++++++++--- 2 files changed, 104 insertions(+), 16 deletions(-) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 95f8d58bf0d7f..1cf32ac2e4d35 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -47,6 +47,8 @@ CommsDecoder, ConnectionResult, DagRunStateResult, + DeleteVariable, + DeleteXCom, DRCount, ErrorResponse, GetConnection, @@ -56,6 +58,9 @@ GetTICount, GetVariable, GetXCom, + OKResponse, + PutVariable, + SetXCom, TaskStatesResult, TICount, VariableResult, @@ -221,6 +226,7 @@ class TriggerStateSync(BaseModel): TICount, TaskStatesResult, ErrorResponse, + OKResponse, ], Field(discriminator="type"), ] @@ -234,8 +240,12 @@ class TriggerStateSync(BaseModel): Union[ messages.TriggerStateChanges, GetConnection, + DeleteVariable, GetVariable, + PutVariable, + DeleteXCom, GetXCom, + SetXCom, GetTICount, GetTaskStates, GetDagRunState, @@ -400,6 +410,8 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r dump_opts = {"exclude_unset": True, "by_alias": True} else: resp = conn + elif isinstance(msg, DeleteVariable): + resp = self.client.variables.delete(msg.key) elif isinstance(msg, GetVariable): var = self.client.variables.get(msg.key) if isinstance(var, VariableResponse): @@ -408,6 +420,10 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r dump_opts = {"exclude_unset": True} else: resp 
= var + elif isinstance(msg, PutVariable): + self.client.variables.set(msg.key, msg.value, msg.description) + elif isinstance(msg, DeleteXCom): + self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) elif isinstance(msg, GetXCom): xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) if isinstance(xcom, XComResponse): @@ -416,6 +432,10 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r dump_opts = {"exclude_unset": True} else: resp = xcom + elif isinstance(msg, SetXCom): + self.client.xcoms.set( + msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index, msg.mapped_length + ) elif isinstance(msg, GetDRCount): dr_count = self.client.dag_runs.get_count( dag_id=msg.dag_id, diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index dd1f2b174c838..5e543a161dbd3 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -630,19 +630,64 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: conn = await sync_to_async(BaseHook.get_connection)("test_connection") self.log.info("Loaded conn %s", conn.conn_id) - variable = await sync_to_async(Variable.get)("test_variable") - self.log.info("Loaded variable %s", variable) + get_variable_value = await sync_to_async(Variable.get)("test_get_variable") + self.log.info("Loaded variable %s", get_variable_value) - xcom = await sync_to_async(XCom.get_one)( - key="test_xcom", + get_xcom_value = await sync_to_async(XCom.get_one)( + key="test_get_xcom", dag_id=self.dag_id, run_id=self.run_id, task_id=self.task_id, map_index=self.map_index, ) - self.log.info("Loaded XCom %s", xcom) + self.log.info("Loaded XCom %s", get_xcom_value) - yield TriggerEvent({"connection": attrs.asdict(conn), "variable": variable, "xcom": xcom}) + set_variable_key = "test_set_variable" + set_variable_value = "set_value" + await sync_to_async(Variable.set)(key=set_variable_key, value=set_variable_value) + self.log.info("Set variable with key %s and value %s", set_variable_key, set_variable_value) + + set_xcom_key = "test_set_xcom" + set_xcom_value = "set_xcom" + await sync_to_async(XCom.set)( + key=set_xcom_key, + dag_id=self.dag_id, + run_id=self.run_id, + task_id=self.task_id, + map_index=self.map_index, + value=set_xcom_value, + ) + self.log.info("Set xcom with key %s and value %s", set_xcom_key, set_xcom_value) + + delete_variable_key = "test_delete_variable" + await sync_to_async(Variable.delete)(delete_variable_key) + self.log.info("Deleted variable with key %s", delete_variable_key) + + delete_xcom_key = "test_delete_xcom" + await sync_to_async(XCom.delete)( + key=delete_xcom_key, + dag_id=self.dag_id, + run_id=self.run_id, + task_id=self.task_id, + map_index=self.map_index, + ) + self.log.info("Delete xcom with key %s", delete_xcom_key) + + yield TriggerEvent( + { + "connection": attrs.asdict(conn), + "variable": { + "get_variable": get_variable_value, + "set_variable": set_variable_value, + "delete_variable": delete_variable_key, + }, + "xcom": { + "get_xcom": get_xcom_value, + "set_xcom": set_xcom_value, + "delete_xcom": delete_xcom_key, + }, + } + ) def serialize(self) -> tuple[str, dict[str, Any]]: return ( @@ -669,8 +714,8 @@ def handle_events(self): @pytest.mark.asyncio @pytest.mark.execution_timeout(20) -async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_maker): - """Checks that the trigger will 
successfully access Variables, Connections and XComs.""" +async def test_trigger_can_call_variables_connections_and_xcoms_methods(session, dag_maker): + """Checks that the trigger will successfully call Variables, Connections and XComs methods.""" # Create the test DAG and task with dag_maker(dag_id="trigger_accessing_variable_connection_and_xcom", session=session): EmptyOperator(task_id="dummy1") @@ -686,7 +731,7 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m kwargs={"dag_id": dr.dag_id, "run_id": dr.run_id, "task_id": task_instance.task_id, "map_index": -1}, ) session.add(trigger_orm) - session.commit() + session.flush() task_instance.trigger_id = trigger_orm.id # Create the appropriate Connection, Variable and XCom @@ -700,9 +745,25 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m port=443, host="example.com", ) - variable = Variable(key="test_variable", val="some_variable_value") + get_variable = Variable(key="test_get_variable", val="some_variable_value") + delete_variable = Variable(key="test_delete_variable", val="delete_value") + + session.add(connection) + session.add(get_variable) + session.add(delete_variable) + XComModel.set( - key="test_xcom", + key="test_get_xcom", + value="some_xcom_value", + task_id=task_instance.task_id, + dag_id=dr.dag_id, + run_id=dr.run_id, + map_index=-1, + session=session, + ) + + XComModel.set( + key="test_delete_xcom", value="some_xcom_value", task_id=task_instance.task_id, dag_id=dr.dag_id, @@ -710,8 +771,6 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m map_index=-1, session=session, ) - session.add(connection) - session.add(variable) job = Job() session.add(job) @@ -723,7 +782,7 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m task_instance.refresh_from_db() assert task_instance.state == TaskInstanceState.SCHEDULED assert task_instance.next_method != "__fail__" - assert task_instance.next_kwargs == { + expected_event = { "event": { "connection": { "conn_id": "test_connection", @@ -736,10 +795,19 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m "port": 443, "extra": '{"key": "value"}', }, - "variable": "some_variable_value", - "xcom": '"some_xcom_value"', + "variable": { + "get_variable": "some_variable_value", + "set_variable": "set_value", + "delete_variable": "test_delete_variable", + }, + "xcom": { + "get_xcom": '"some_xcom_value"', + "set_xcom": "set_xcom", + "delete_xcom": "test_delete_xcom", + }, } } + assert task_instance.next_kwargs == expected_event class CustomTriggerDagRun(BaseTrigger): From 34b2f131b3a08c288dffec28e13625293f09ffbb Mon Sep 17 00:00:00 2001 From: Yaming Zhang Date: Wed, 6 Aug 2025 02:06:25 -0700 Subject: [PATCH 11/11] Fix type error with TIH when reading served log (#54114) (cherry picked from commit 3df18a07a1c2bb0f85ebdf7d889d687c70d39b47) --- .../src/airflow/utils/log/file_task_handler.py | 4 ++-- airflow-core/tests/unit/utils/test_log_handlers.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/utils/log/file_task_handler.py b/airflow-core/src/airflow/utils/log/file_task_handler.py index 4aa802f37c890..37149b05d6cd8 100644 --- a/airflow-core/src/airflow/utils/log/file_task_handler.py +++ b/airflow-core/src/airflow/utils/log/file_task_handler.py @@ -867,13 +867,13 @@ def _read_from_local( def _read_from_logs_server( self, - ti: TaskInstance, + ti: TaskInstance | TaskInstanceHistory, 
worker_log_rel_path: str, ) -> LogResponse: sources: LogSourceInfo = [] log_streams: list[RawLogStream] = [] try: - log_type = LogType.TRIGGER if ti.triggerer_job else LogType.WORKER + log_type = LogType.TRIGGER if getattr(ti, "triggerer_job", False) else LogType.WORKER url, rel_path = self._get_log_retrieval_url(ti, worker_log_rel_path, log_type=log_type) response = _fetch_logs_from_service(url, rel_path) if response.status_code == 403: diff --git a/airflow-core/tests/unit/utils/test_log_handlers.py b/airflow-core/tests/unit/utils/test_log_handlers.py index b4be1c3f895cf..d3a2512a06778 100644 --- a/airflow-core/tests/unit/utils/test_log_handlers.py +++ b/airflow-core/tests/unit/utils/test_log_handlers.py @@ -557,6 +557,18 @@ def test__read_served_logs_checked_when_done_and_no_local_or_remote_logs( assert extract_events(logs, False) == expected_logs assert metadata == {"end_of_log": True, "log_pos": 3} + @pytest.mark.parametrize("is_tih", [False, True]) + def test_read_served_logs(self, is_tih, create_task_instance): + ti = create_task_instance( + state=TaskInstanceState.SUCCESS, + hostname="test_hostname", + ) + if is_tih: + ti = TaskInstanceHistory(ti, ti.state) + fth = FileTaskHandler("") + sources, _ = fth._read_from_logs_server(ti, "test.log") + assert len(sources) > 0 + def test_add_triggerer_suffix(self): sample = "any/path/to/thing.txt" assert FileTaskHandler.add_triggerer_suffix(sample) == sample + ".trigger"
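
Note on the fix above: `TaskInstanceHistory` rows appear not to carry the
`triggerer_job` relationship that live `TaskInstance` rows have, so the bare
`ti.triggerer_job` access could fail when serving logs for a historical try.
The `getattr(..., False)` guard keeps log-type detection working for both row
types. A minimal standalone sketch of the pattern, where the classes below are
illustrative stand-ins rather than the real Airflow models:

    class TaskInstance:
        triggerer_job = None  # populated once a triggerer picks up the task

    class TaskInstanceHistory:
        """Snapshot of a past try; defines no `triggerer_job` attribute."""

    def pick_log_type(ti) -> str:
        # getattr() with a falsy default handles both row types, whereas
        # bare `ti.triggerer_job` raises AttributeError for history rows.
        return "trigger" if getattr(ti, "triggerer_job", False) else "worker"

    print(pick_log_type(TaskInstance()))         # -> worker
    print(pick_log_type(TaskInstanceHistory()))  # -> worker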