diff --git a/airflow-core/docs/administration-and-deployment/dag-bundles.rst b/airflow-core/docs/administration-and-deployment/dag-bundles.rst
index 057d354b1f523..48401b0c1439c 100644
--- a/airflow-core/docs/administration-and-deployment/dag-bundles.rst
+++ b/airflow-core/docs/administration-and-deployment/dag-bundles.rst
@@ -50,6 +50,9 @@ Airflow supports multiple types of dag Bundles, each catering to specific use ca
 **airflow.providers.git.bundles.git.GitDagBundle**
     These bundles integrate with Git repositories, allowing Airflow to fetch dags directly from a repository.
 
+**airflow.providers.amazon.aws.bundles.s3.S3DagBundle**
+    These bundles reference an S3 bucket containing DAG files. They do not support versioning of the bundle, meaning tasks always run using the latest code.
+
 Configuring dag bundles
 -----------------------
@@ -65,7 +68,7 @@ For example, adding multiple dag bundles to your ``airflow.cfg`` file:
     dag_bundle_config_list = [
         {
             "name": "my_git_repo",
-            "classpath": "airflow.dag_processing.bundles.git.GitDagBundle",
+            "classpath": "airflow.providers.git.bundles.git.GitDagBundle",
             "kwargs": {"tracking_ref": "main", "git_conn_id": "my_git_conn"}
         },
         {
diff --git a/airflow-core/docs/administration-and-deployment/logging-monitoring/advanced-logging-configuration.rst b/airflow-core/docs/administration-and-deployment/logging-monitoring/advanced-logging-configuration.rst
index 342de5e016a3b..fa34e74f3140b 100644
--- a/airflow-core/docs/administration-and-deployment/logging-monitoring/advanced-logging-configuration.rst
+++ b/airflow-core/docs/administration-and-deployment/logging-monitoring/advanced-logging-configuration.rst
@@ -34,7 +34,7 @@ Some configuration options require that the logging config class be overwritten.
 configuration of Airflow and modifying it to suit your needs.
 
 The default configuration can be seen in the
-`airflow_local_settings.py template `_
+`airflow_local_settings.py template `_
 and you can see the loggers and handlers used there.
 
 See :ref:`Configuring local settings ` for details on how to
diff --git a/airflow-core/docs/installation/upgrading_to_airflow3.rst b/airflow-core/docs/installation/upgrading_to_airflow3.rst
index 560f352ee4138..8b733ad27f061 100644
--- a/airflow-core/docs/installation/upgrading_to_airflow3.rst
+++ b/airflow-core/docs/installation/upgrading_to_airflow3.rst
@@ -108,9 +108,25 @@ Some changes can be automatically fixed. To do so, run the following command:
 
     ruff check dag/ --select AIR301 --fix --preview
 
+Some of the fixes are marked as unsafe. Unsafe fixes usually do not break dag code; they are marked as unsafe because they may change some runtime behavior. For more information, see `Fix Safety `_.
+To trigger these fixes, run the following command:
+
+.. code-block:: bash
+
+    ruff check dags/ --select AIR301 --fix --unsafe-fixes --preview
+
+.. note::
+    Ruff has a strict policy about when a rule becomes stable. Until a rule becomes stable, you must use the ``--preview`` flag.
+    The progress of the Airflow Ruff rules becoming stable can be tracked in https://github.com/astral-sh/ruff/issues/17749.
+    That said, from the Airflow side the rules are perfectly fine to use.
+
+.. note::
+
+    In AIR rules, unsafe fixes involve changing import paths while keeping the name of the imported member the same. For instance, changing the import from ``from airflow.sensors.base_sensor_operator import BaseSensorOperator`` to ``from airflow.sdk.bases.sensor import BaseSensorOperator`` requires ruff to remove the original import before adding the new one.
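+
+    As an illustrative sketch (using the import pair above; the exact rewrite depends on the rule being fixed), such an unsafe fix transforms a dag file as follows:
+
+    .. code-block:: python
+
+        # Before: the Airflow 2 import path flagged by AIR301
+        from airflow.sensors.base_sensor_operator import BaseSensorOperator
+
+        # After running "ruff check dags/ --select AIR301 --fix --unsafe-fixes --preview":
+        # the original import is removed and the Airflow 3 path is added in its place.
+        from airflow.sdk.bases.sensor import BaseSensorOperator
+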
+    In contrast, safe fixes include changes to both the member name and the import path, such as changing ``from airflow.datasets import Dataset`` to ``from airflow.sdk import Asset``. These adjustments do not require ruff to remove the old import. To remove unused legacy imports, it is necessary to enable the ``unused-import`` rule (F401).
+
 You can also configure these flags through configuration files. See `Configuring Ruff `_ for details.
 
-Step 4: Install the Standard Providers
+Step 4: Install the Standard Provider
 --------------------------------------
 
 - Some of the commonly used Operators which were bundled as part of the ``airflow-core`` package (for example ``BashOperator`` and ``PythonOperator``)
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index 2d7968bbc625e..f1e42ef1e0b72 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -35,7 +35,12 @@
 from airflow.api_fastapi.execution_api.datamodels.asset import AssetProfile
 from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse
 from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse
-from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState
+from airflow.utils.state import (
+    DagRunState,
+    IntermediateTIState,
+    TaskInstanceState as TIState,
+    TerminalTIState,
+)
 from airflow.utils.types import DagRunType
 
 AwareDatetimeAdapter = TypeAdapter(AwareDatetime)
@@ -292,6 +297,7 @@ class DagRun(StrictBaseModel):
     end_date: UtcDateTime | None
     clear_number: int = 0
     run_type: DagRunType
+    state: DagRunState
     conf: Annotated[dict[str, Any], Field(default_factory=dict)]
     consumed_asset_events: list[AssetEventDagRunReference]
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
index db27f7bd93ed8..d0dcef4faebe2 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
@@ -27,10 +27,12 @@
 from airflow.api_fastapi.common.db.common import SessionDep
 from airflow.api_fastapi.common.types import UtcDateTime
 from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun
 from airflow.exceptions import DagRunAlreadyExists
 from airflow.models.dag import DagModel
 from airflow.models.dagbag import DagBag
-from airflow.models.dagrun import DagRun
+from airflow.models.dagrun import DagRun as DagRunModel
+from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunTriggeredByType
 
 router = APIRouter()
@@ -140,7 +142,9 @@ def get_dagrun_state(
     session: SessionDep,
 ) -> DagRunStateResponse:
     """Get a DAG Run State."""
-    dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id))
+    dag_run = session.scalar(
+        select(DagRunModel).where(DagRunModel.dag_id == dag_id, DagRunModel.run_id == run_id)
+    )
     if dag_run is None:
         raise HTTPException(
             status.HTTP_404_NOT_FOUND,
@@ -162,16 +166,45 @@ def get_dr_count(
     states: Annotated[list[str] | None, Query()] = None,
 ) -> int:
     """Get the count of DAG runs matching the given criteria."""
-    query =
select(func.count()).select_from(DagRun).where(DagRun.dag_id == dag_id) + query = select(func.count()).select_from(DagRunModel).where(DagRunModel.dag_id == dag_id) if logical_dates: - query = query.where(DagRun.logical_date.in_(logical_dates)) + query = query.where(DagRunModel.logical_date.in_(logical_dates)) if run_ids: - query = query.where(DagRun.run_id.in_(run_ids)) + query = query.where(DagRunModel.run_id.in_(run_ids)) if states: - query = query.where(DagRun.state.in_(states)) + query = query.where(DagRunModel.state.in_(states)) count = session.scalar(query) return count or 0 + + +@router.get("/{dag_id}/previous", status_code=status.HTTP_200_OK) +def get_previous_dagrun( + dag_id: str, + logical_date: UtcDateTime, + session: SessionDep, + state: Annotated[DagRunState | None, Query()] = None, +) -> DagRun | None: + """Get the previous DAG run before the given logical date, optionally filtered by state.""" + query = ( + select(DagRunModel) + .where( + DagRunModel.dag_id == dag_id, + DagRunModel.logical_date < logical_date, + ) + .order_by(DagRunModel.logical_date.desc()) + .limit(1) + ) + + if state: + query = query.where(DagRunModel.state == state) + + dag_run = session.scalar(query) + + if not dag_run: + return None + + return DagRun.model_validate(dag_run) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index 9762c7b34985e..eae3539f59317 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -230,6 +230,7 @@ class GetXComSliceFilterParams(BaseModel): start: int | None = None stop: int | None = None step: int | None = None + include_prior_dates: bool = False @router.get( @@ -249,6 +250,7 @@ def get_mapped_xcom_by_slice( key=key, task_ids=task_id, dag_ids=dag_id, + include_prior_dates=params.include_prior_dates, session=session, ) query = query.order_by(None) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 5462f10297495..aeadd1affc92b 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -21,9 +21,18 @@ from airflow.api_fastapi.execution_api.versions.v2025_04_28 import AddRenderedMapIndexField from airflow.api_fastapi.execution_api.versions.v2025_05_20 import DowngradeUpstreamMapIndexes +from airflow.api_fastapi.execution_api.versions.v2025_08_10 import ( + AddDagRunStateFieldAndPreviousEndpoint, + AddIncludePriorDatesToGetXComSlice, +) bundle = VersionBundle( HeadVersion(), + Version( + "2025-08-10", + AddDagRunStateFieldAndPreviousEndpoint, + AddIncludePriorDatesToGetXComSlice, + ), Version("2025-05-20", DowngradeUpstreamMapIndexes), Version("2025-04-28", AddRenderedMapIndexField), Version("2025-04-11"), diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py new file mode 100644 index 0000000000000..ec66915e4d908 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, endpoint, schema + +from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun, TIRunContext +from airflow.api_fastapi.execution_api.routes.xcoms import GetXComSliceFilterParams + + +class AddDagRunStateFieldAndPreviousEndpoint(VersionChange): + """Add the `state` field to DagRun model and `/dag-runs/{dag_id}/previous` endpoint.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + schema(DagRun).field("state").didnt_exist, + endpoint("/dag-runs/{dag_id}/previous", ["GET"]).didnt_exist, + ) + + @convert_response_to_previous_version_for(TIRunContext) # type: ignore[arg-type] + def remove_state_from_dag_run(response: ResponseInfo) -> None: # type: ignore[misc] + """Remove the `state` field from the dag_run object when converting to the previous version.""" + if "dag_run" in response.body and isinstance(response.body["dag_run"], dict): + response.body["dag_run"].pop("state", None) + + +class AddIncludePriorDatesToGetXComSlice(VersionChange): + """Add the `include_prior_dates` field to GetXComSliceFilterParams.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + schema(GetXComSliceFilterParams).field("include_prior_dates").didnt_exist, + ) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 95f8d58bf0d7f..1cf32ac2e4d35 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -47,6 +47,8 @@ CommsDecoder, ConnectionResult, DagRunStateResult, + DeleteVariable, + DeleteXCom, DRCount, ErrorResponse, GetConnection, @@ -56,6 +58,9 @@ GetTICount, GetVariable, GetXCom, + OKResponse, + PutVariable, + SetXCom, TaskStatesResult, TICount, VariableResult, @@ -221,6 +226,7 @@ class TriggerStateSync(BaseModel): TICount, TaskStatesResult, ErrorResponse, + OKResponse, ], Field(discriminator="type"), ] @@ -234,8 +240,12 @@ class TriggerStateSync(BaseModel): Union[ messages.TriggerStateChanges, GetConnection, + DeleteVariable, GetVariable, + PutVariable, + DeleteXCom, GetXCom, + SetXCom, GetTICount, GetTaskStates, GetDagRunState, @@ -400,6 +410,8 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r dump_opts = {"exclude_unset": True, "by_alias": True} else: resp = conn + elif isinstance(msg, DeleteVariable): + resp = self.client.variables.delete(msg.key) elif isinstance(msg, GetVariable): var = self.client.variables.get(msg.key) if isinstance(var, VariableResponse): @@ -408,6 +420,10 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r dump_opts = {"exclude_unset": True} else: resp = var + elif isinstance(msg, PutVariable): + self.client.variables.set(msg.key, msg.value, msg.description) + elif isinstance(msg, DeleteXCom): + 
self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) elif isinstance(msg, GetXCom): xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) if isinstance(xcom, XComResponse): @@ -416,6 +432,10 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r dump_opts = {"exclude_unset": True} else: resp = xcom + elif isinstance(msg, SetXCom): + self.client.xcoms.set( + msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index, msg.mapped_length + ) elif isinstance(msg, GetDRCount): dr_count = self.client.dag_runs.get_count( dag_id=msg.dag_id, diff --git a/airflow-core/src/airflow/models/__init__.py b/airflow-core/src/airflow/models/__init__.py index 9274ae7a79f3f..b5a8e61e49ae7 100644 --- a/airflow-core/src/airflow/models/__init__.py +++ b/airflow-core/src/airflow/models/__init__.py @@ -19,6 +19,8 @@ from __future__ import annotations +from airflow.utils.deprecation_tools import add_deprecated_classes + # Do not add new models to this -- this is for compat only __all__ = [ "DAG", @@ -141,3 +143,22 @@ def __getattr__(name): from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions.param import Param from airflow.sdk.execution_time.xcom import XCom + +__deprecated_classes = { + "abstractoperator": { + "AbstractOperator": "airflow.sdk.definitions._internal.abstractoperator.AbstractOperator", + "NotMapped": "airflow.sdk.definitions._internal.abstractoperator.NotMapped", + "TaskStateChangeCallback": "airflow.sdk.definitions._internal.abstractoperator.TaskStateChangeCallback", + "DEFAULT_OWNER": "airflow.sdk.definitions._internal.abstractoperator.DEFAULT_OWNER", + "DEFAULT_QUEUE": "airflow.sdk.definitions._internal.abstractoperator.DEFAULT_QUEUE", + "DEFAULT_TASK_EXECUTION_TIMEOUT": "airflow.sdk.definitions._internal.abstractoperator.DEFAULT_TASK_EXECUTION_TIMEOUT", + }, + "param": { + "Param": "airflow.sdk.definitions.param.Param", + "ParamsDict": "airflow.sdk.definitions.param.ParamsDict", + }, + "baseoperatorlink": { + "BaseOperatorLink": "airflow.sdk.bases.operatorlink.BaseOperatorLink", + }, +} +add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow-core/src/airflow/models/abstractoperator.py b/airflow-core/src/airflow/models/abstractoperator.py deleted file mode 100644 index e5b5f7dc81f45..0000000000000 --- a/airflow-core/src/airflow/models/abstractoperator.py +++ /dev/null @@ -1,34 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -import datetime - -from airflow.configuration import conf -from airflow.sdk.definitions._internal.abstractoperator import ( - AbstractOperator as AbstractOperator, - NotMapped as NotMapped, # Re-export this for compat - TaskStateChangeCallback as TaskStateChangeCallback, -) - -DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner") -DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue") - -DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta( - "core", "default_task_execution_timeout" -) diff --git a/airflow-core/src/airflow/utils/log/file_task_handler.py b/airflow-core/src/airflow/utils/log/file_task_handler.py index 4aa802f37c890..37149b05d6cd8 100644 --- a/airflow-core/src/airflow/utils/log/file_task_handler.py +++ b/airflow-core/src/airflow/utils/log/file_task_handler.py @@ -867,13 +867,13 @@ def _read_from_local( def _read_from_logs_server( self, - ti: TaskInstance, + ti: TaskInstance | TaskInstanceHistory, worker_log_rel_path: str, ) -> LogResponse: sources: LogSourceInfo = [] log_streams: list[RawLogStream] = [] try: - log_type = LogType.TRIGGER if ti.triggerer_job else LogType.WORKER + log_type = LogType.TRIGGER if getattr(ti, "triggerer_job", False) else LogType.WORKER url, rel_path = self._get_log_retrieval_url(ti, worker_log_rel_path, log_type=log_type) response = _fetch_logs_from_service(url, rel_path) if response.status_code == 403: diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py index f9f8d489d3d26..ac414f53ee83a 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py @@ -314,3 +314,142 @@ def test_get_count_with_mixed_states(self, client, session, dag_maker): ) assert response.status_code == 200 assert response.json() == 2 + + +class TestGetPreviousDagRun: + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + def test_get_previous_dag_run_basic(self, client, session, dag_maker): + """Test getting the previous DAG run without state filtering.""" + dag_id = "test_get_previous_basic" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + # Create multiple DAG runs + dag_maker.create_dagrun( + run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS + ) + dag_maker.create_dagrun( + run_id="run2", logical_date=timezone.datetime(2025, 1, 5), state=DagRunState.FAILED + ) + dag_maker.create_dagrun( + run_id="run3", logical_date=timezone.datetime(2025, 1, 10), state=DagRunState.SUCCESS + ) + session.commit() + + # Query for previous DAG run before 2025-01-10 + response = client.get( + f"/execution/dag-runs/{dag_id}/previous", + params={ + "logical_date": timezone.datetime(2025, 1, 10).isoformat(), + }, + ) + + assert response.status_code == 200 + result = response.json() + assert result["dag_id"] == dag_id + assert result["run_id"] == "run2" # Most recent before 2025-01-10 + assert result["state"] == "failed" + + def test_get_previous_dag_run_with_state_filter(self, client, session, dag_maker): + """Test getting the previous DAG run with state filtering.""" + dag_id = "test_get_previous_with_state" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + # 
Create multiple DAG runs with different states
+        dag_maker.create_dagrun(
+            run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS
+        )
+        dag_maker.create_dagrun(
+            run_id="run2", logical_date=timezone.datetime(2025, 1, 5), state=DagRunState.FAILED
+        )
+        dag_maker.create_dagrun(
+            run_id="run3", logical_date=timezone.datetime(2025, 1, 8), state=DagRunState.SUCCESS
+        )
+        session.commit()
+
+        # Query for previous successful DAG run before 2025-01-10
+        response = client.get(
+            f"/execution/dag-runs/{dag_id}/previous",
+            params={"logical_date": timezone.datetime(2025, 1, 10).isoformat(), "state": "success"},
+        )
+
+        assert response.status_code == 200
+        result = response.json()
+        assert result["dag_id"] == dag_id
+        assert result["run_id"] == "run3"  # Most recent successful run before 2025-01-10
+        assert result["state"] == "success"
+
+    def test_get_previous_dag_run_no_previous_found(self, client, session, dag_maker):
+        """Test getting previous DAG run when none exists returns null."""
+        dag_id = "test_get_previous_none"
+
+        with dag_maker(dag_id=dag_id, session=session, serialized=True):
+            EmptyOperator(task_id="test_task")
+
+        # Create only one DAG run - no previous should exist
+        dag_maker.create_dagrun(
+            run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS
+        )
+
+        response = client.get(f"/execution/dag-runs/{dag_id}/previous?logical_date=2025-01-01T00:00:00Z")
+
+        assert response.status_code == 200
+        assert response.json() is None  # Should return null
+
+    def test_get_previous_dag_run_no_matching_state(self, client, session, dag_maker):
+        """Test getting previous DAG run with state filter that matches nothing returns null."""
+        dag_id = "test_get_previous_no_match"
+
+        with dag_maker(dag_id=dag_id, session=session, serialized=True):
+            EmptyOperator(task_id="test_task")
+
+        # Create DAG runs with different states
+        dag_maker.create_dagrun(
+            run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.FAILED
+        )
+        dag_maker.create_dagrun(
+            run_id="run2", logical_date=timezone.datetime(2025, 1, 2), state=DagRunState.FAILED
+        )
+
+        # Look for previous success but only failed runs exist
+        response = client.get(
+            f"/execution/dag-runs/{dag_id}/previous?logical_date=2025-01-03T00:00:00Z&state=success"
+        )
+
+        assert response.status_code == 200
+        assert response.json() is None
+
+    def test_get_previous_dag_run_dag_not_found(self, client, session):
+        """Test getting the previous DAG run for a non-existent DAG returns null."""
+        response = client.get(
+            "/execution/dag-runs/nonexistent_dag/previous?logical_date=2025-01-01T00:00:00Z"
+        )
+
+        assert response.status_code == 200
+        assert response.json() is None
+
+    def test_get_previous_dag_run_invalid_state_parameter(self, client, session, dag_maker):
+        """Test that invalid state parameter returns 422 validation error."""
+        dag_id = "test_get_previous_invalid_state"
+
+        with dag_maker(dag_id=dag_id, session=session, serialized=True):
+            EmptyOperator(task_id="test_task")
+
+        dag_maker.create_dagrun(
+            run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS
+        )
+        session.commit()
+
+        response = client.get(
+            f"/execution/dag-runs/{dag_id}/previous?logical_date=2025-01-02T00:00:00Z&state=invalid_state"
+        )
+
+        assert response.status_code == 422
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index
3c107f61863f2..4f1e28a11f628 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -35,7 +35,7 @@ from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk import Asset, TaskGroup, task, task_group from airflow.utils import timezone -from airflow.utils.state import State, TaskInstanceState, TerminalTIState +from airflow.utils.state import DagRunState, State, TaskInstanceState, TerminalTIState from tests_common.test_utils.db import ( clear_db_assets, @@ -155,6 +155,7 @@ def test_ti_run_state_to_running( ti = create_task_instance( task_id="test_ti_run_state_to_running", state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, session=session, start_date=instant, dag_id=str(uuid4()), @@ -184,6 +185,7 @@ def test_ti_run_state_to_running( "data_interval_end": instant_str, "run_after": instant_str, "start_date": instant_str, + "state": "running", "end_date": None, "run_type": "manual", "conf": {}, diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py index 1b10e81cd2338..ea93c2f96e00c 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py @@ -31,6 +31,7 @@ from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator from airflow.serialization.serde import deserialize, serialize +from airflow.utils import timezone from airflow.utils.session import create_session pytestmark = pytest.mark.db_test @@ -273,6 +274,54 @@ def __init__(self, *, x, **kwargs): assert response.status_code == 200 assert response.json() == ["f", "o", "b"][key] + @pytest.mark.parametrize( + "include_prior_dates, expected_xcoms", + [[True, ["earlier_value", "later_value"]], [False, ["later_value"]]], + ) + def test_xcom_get_slice_accepts_include_prior_dates( + self, client, dag_maker, session, include_prior_dates, expected_xcoms + ): + """Test that the slice endpoint accepts include_prior_dates parameter and works correctly.""" + + with dag_maker(dag_id="dag"): + EmptyOperator(task_id="task") + + earlier_run = dag_maker.create_dagrun( + run_id="earlier_run", logical_date=timezone.parse("2024-01-01T00:00:00Z") + ) + later_run = dag_maker.create_dagrun( + run_id="later_run", logical_date=timezone.parse("2024-01-02T00:00:00Z") + ) + + earlier_ti = earlier_run.get_task_instance("task") + later_ti = later_run.get_task_instance("task") + + earlier_xcom = XComModel( + key="test_key", + value="earlier_value", + dag_run_id=earlier_ti.dag_run.id, + run_id=earlier_ti.run_id, + task_id=earlier_ti.task_id, + dag_id=earlier_ti.dag_id, + ) + later_xcom = XComModel( + key="test_key", + value="later_value", + dag_run_id=later_ti.dag_run.id, + run_id=later_ti.run_id, + task_id=later_ti.task_id, + dag_id=later_ti.dag_id, + ) + session.add_all([earlier_xcom, later_xcom]) + session.commit() + + response = client.get( + f"/execution/xcoms/dag/later_run/task/test_key/slice?include_prior_dates={include_prior_dates}" + ) + assert response.status_code == 200 + + assert response.json() == expected_xcoms + class TestXComsSetEndpoint: @pytest.mark.parametrize( diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 
1a3925c077db3..dc85bfb1b21fa 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -57,6 +57,7 @@
 from airflow.models.baseoperator import BaseOperator
 from airflow.sdk import DAG
 from airflow.sdk.api.client import Client
+from airflow.sdk.api.datamodels._generated import DagRunState
 from airflow.sdk.execution_time import comms
 from airflow.utils import timezone
 from airflow.utils.session import create_session
@@ -957,6 +958,7 @@ def fake_collect_dags(self, *args, **kwargs):
             logical_date=timezone.utcnow(),
             start_date=timezone.utcnow(),
             run_type="manual",
+            state=DagRunState.RUNNING,
         )
         dag_run.run_after = timezone.utcnow()
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index dd1f2b174c838..5e543a161dbd3 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -630,19 +630,64 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]:
         conn = await sync_to_async(BaseHook.get_connection)("test_connection")
         self.log.info("Loaded conn %s", conn.conn_id)
 
-        variable = await sync_to_async(Variable.get)("test_variable")
-        self.log.info("Loaded variable %s", variable)
+        get_variable_value = await sync_to_async(Variable.get)("test_get_variable")
+        self.log.info("Loaded variable %s", get_variable_value)
 
-        xcom = await sync_to_async(XCom.get_one)(
-            key="test_xcom",
+        get_xcom_value = await sync_to_async(XCom.get_one)(
+            key="test_get_xcom",
             dag_id=self.dag_id,
             run_id=self.run_id,
             task_id=self.task_id,
             map_index=self.map_index,
         )
-        self.log.info("Loaded XCom %s", xcom)
+        self.log.info("Loaded XCom %s", get_xcom_value)
 
-        yield TriggerEvent({"connection": attrs.asdict(conn), "variable": variable, "xcom": xcom})
+        set_variable_key = "test_set_variable"
+        set_variable_value = "set_value"
+        await sync_to_async(Variable.set)(key=set_variable_key, value=set_variable_value)
+        self.log.info("Set variable with key %s and value %s", set_variable_key, set_variable_value)
+
+        set_xcom_key = "test_set_xcom"
+        set_xcom_value = "set_xcom"
+        await sync_to_async(XCom.set)(
+            key=set_xcom_key,
+            dag_id=self.dag_id,
+            run_id=self.run_id,
+            task_id=self.task_id,
+            map_index=self.map_index,
+            value=set_xcom_value,
+        )
+        self.log.info("Set xcom with key %s and value %s", set_xcom_key, set_xcom_value)
+
+        delete_variable_key = "test_delete_variable"
+        await sync_to_async(Variable.delete)(delete_variable_key)
+        self.log.info("Deleted variable with key %s", delete_variable_key)
+
+        delete_xcom_key = "test_delete_xcom"
+        await sync_to_async(XCom.delete)(
+            key=delete_xcom_key,
+            dag_id=self.dag_id,
+            run_id=self.run_id,
+            task_id=self.task_id,
+            map_index=self.map_index,
+        )
+        self.log.info("Deleted xcom with key %s", delete_xcom_key)
+
+        yield TriggerEvent(
+            {
+                "connection": attrs.asdict(conn),
+                "variable": {
+                    "get_variable": get_variable_value,
+                    "set_variable": set_variable_value,
+                    "delete_variable": delete_variable_key,
+                },
+                "xcom": {
+                    "get_xcom": get_xcom_value,
+                    "set_xcom": set_xcom_value,
+                    "delete_xcom": delete_xcom_key,
+                },
+            }
+        )
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         return (
@@ -669,8 +714,8 @@ def handle_events(self):
 
 @pytest.mark.asyncio
 @pytest.mark.execution_timeout(20)
-async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_maker):
-    """Checks that the trigger will successfully access Variables, Connections and XComs."""
+async def
test_trigger_can_call_variables_connections_and_xcoms_methods(session, dag_maker): + """Checks that the trigger will successfully call Variables, Connections and XComs methods.""" # Create the test DAG and task with dag_maker(dag_id="trigger_accessing_variable_connection_and_xcom", session=session): EmptyOperator(task_id="dummy1") @@ -686,7 +731,7 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m kwargs={"dag_id": dr.dag_id, "run_id": dr.run_id, "task_id": task_instance.task_id, "map_index": -1}, ) session.add(trigger_orm) - session.commit() + session.flush() task_instance.trigger_id = trigger_orm.id # Create the appropriate Connection, Variable and XCom @@ -700,9 +745,25 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m port=443, host="example.com", ) - variable = Variable(key="test_variable", val="some_variable_value") + get_variable = Variable(key="test_get_variable", val="some_variable_value") + delete_variable = Variable(key="test_delete_variable", val="delete_value") + + session.add(connection) + session.add(get_variable) + session.add(delete_variable) + XComModel.set( - key="test_xcom", + key="test_get_xcom", + value="some_xcom_value", + task_id=task_instance.task_id, + dag_id=dr.dag_id, + run_id=dr.run_id, + map_index=-1, + session=session, + ) + + XComModel.set( + key="test_delete_xcom", value="some_xcom_value", task_id=task_instance.task_id, dag_id=dr.dag_id, @@ -710,8 +771,6 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m map_index=-1, session=session, ) - session.add(connection) - session.add(variable) job = Job() session.add(job) @@ -723,7 +782,7 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m task_instance.refresh_from_db() assert task_instance.state == TaskInstanceState.SCHEDULED assert task_instance.next_method != "__fail__" - assert task_instance.next_kwargs == { + expected_event = { "event": { "connection": { "conn_id": "test_connection", @@ -736,10 +795,19 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m "port": 443, "extra": '{"key": "value"}', }, - "variable": "some_variable_value", - "xcom": '"some_xcom_value"', + "variable": { + "get_variable": "some_variable_value", + "set_variable": "set_value", + "delete_variable": "test_delete_variable", + }, + "xcom": { + "get_xcom": '"some_xcom_value"', + "set_xcom": "set_xcom", + "delete_xcom": "test_delete_xcom", + }, } } + assert task_instance.next_kwargs == expected_event class CustomTriggerDagRun(BaseTrigger): diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index a7508b3ea520d..f4a719496f27d 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -187,6 +187,7 @@ "bash_command": "echo {{ task.task_id }}", "task_type": "BashOperator", "_task_module": "airflow.providers.standard.operators.bash", + "owner": "airflow", "pool": "default_pool", "is_setup": False, "is_teardown": False, @@ -3163,6 +3164,7 @@ def test_handle_v1_serdag(): "_task_type": "BashOperator", # Slightly difference from v2-10-stable here, we manually changed this path "_task_module": "airflow.providers.standard.operators.bash", + "owner": "airflow", "pool": "default_pool", "is_setup": False, "is_teardown": False, diff --git a/airflow-core/tests/unit/utils/test_log_handlers.py 
b/airflow-core/tests/unit/utils/test_log_handlers.py index b4be1c3f895cf..d3a2512a06778 100644 --- a/airflow-core/tests/unit/utils/test_log_handlers.py +++ b/airflow-core/tests/unit/utils/test_log_handlers.py @@ -557,6 +557,18 @@ def test__read_served_logs_checked_when_done_and_no_local_or_remote_logs( assert extract_events(logs, False) == expected_logs assert metadata == {"end_of_log": True, "log_pos": 3} + @pytest.mark.parametrize("is_tih", [False, True]) + def test_read_served_logs(self, is_tih, create_task_instance): + ti = create_task_instance( + state=TaskInstanceState.SUCCESS, + hostname="test_hostname", + ) + if is_tih: + ti = TaskInstanceHistory(ti, ti.state) + fth = FileTaskHandler("") + sources, _ = fth._read_from_logs_server(ti, "test.log") + assert len(sources) > 0 + def test_add_triggerer_suffix(self): sample = "any/path/to/thing.txt" assert FileTaskHandler.add_triggerer_suffix(sample) == sample + ".trigger" diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 2288fefc26880..e161771b84aad 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2134,7 +2134,7 @@ def _create_task_instance( should_retry: bool | None = None, max_tries: int | None = None, ) -> RuntimeTaskInstance: - from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext + from airflow.sdk.api.datamodels._generated import DagRun, DagRunState, TIRunContext from airflow.utils.types import DagRunType if not ti_id: @@ -2167,17 +2167,20 @@ def _create_task_instance( run_after = data_interval_end or logical_date or timezone.utcnow() ti_context = TIRunContext( - dag_run=DagRun( - dag_id=dag_id, - run_id=run_id, - logical_date=logical_date, # type: ignore - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, - start_date=start_date, # type: ignore - run_type=run_type, # type: ignore - run_after=run_after, # type: ignore - conf=conf, - consumed_asset_events=[], + dag_run=DagRun.model_validate( + { + "dag_id": dag_id, + "run_id": run_id, + "logical_date": logical_date, # type: ignore + "data_interval_start": data_interval_start, + "data_interval_end": data_interval_end, + "start_date": start_date, # type: ignore + "run_type": run_type, # type: ignore + "run_after": run_after, # type: ignore + "conf": conf, + "consumed_asset_events": [], + **({"state": DagRunState.RUNNING} if "state" in DagRun.model_fields else {}), + } ), task_reschedule_count=task_reschedule_count, max_tries=task_retries if max_tries is None else max_tries, diff --git a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py index 55d5414cd8965..20bf71623d179 100644 --- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py +++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py @@ -30,7 +30,11 @@ from airflow.cli.cli_config import GroupCommand from airflow.configuration import conf from airflow.executors.base_executor import BaseExecutor -from airflow.models.abstractoperator import DEFAULT_QUEUE + +try: + from airflow.models.abstractoperator import DEFAULT_QUEUE +except (ImportError, AttributeError): + from airflow.sdk.definitions._internal.abstractoperator import DEFAULT_QUEUE from airflow.models.taskinstance import TaskInstance, TaskInstanceState from airflow.providers.edge3.cli.edge_command import EDGE_COMMANDS from airflow.providers.edge3.models.edge_job import 
EdgeJobModel diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py index eff8dc9efcb88..555725f0246a3 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py @@ -853,6 +853,7 @@ def _create_listener_and_task_instance( task_instance.dag_run.clear_number = 0 task_instance.dag_run.logical_date = timezone.datetime(2020, 1, 1, 1, 1, 1) task_instance.dag_run.run_after = timezone.datetime(2020, 1, 1, 1, 1, 1) + task_instance.dag_run.state = DagRunState.RUNNING task_instance.task = None task_instance.dag = None task_instance.task_id = "task_id" @@ -862,6 +863,7 @@ def _create_listener_and_task_instance( # RuntimeTaskInstance is used when on worker from airflow.sdk.api.datamodels._generated import ( DagRun as SdkDagRun, + DagRunState as SdkDagRunState, DagRunType, TaskInstance as SdkTaskInstance, TIRunContext, @@ -887,19 +889,20 @@ def _create_listener_and_task_instance( **sdk_task_instance.model_dump(exclude_unset=True), task=task, _ti_context_from_server=TIRunContext( - dag_run=SdkDagRun( - dag_id="dag_id", - run_id="dag_run_run_id", - logical_date=timezone.datetime(2020, 1, 1, 1, 1, 1), - data_interval_start=None, - data_interval_end=None, - start_date=timezone.datetime(2023, 1, 1, 13, 1, 1), - end_date=timezone.datetime(2023, 1, 3, 13, 1, 1), - clear_number=0, - run_type=DagRunType.MANUAL, - run_after=timezone.datetime(2023, 1, 3, 13, 1, 1), - conf=None, - consumed_asset_events=[], + dag_run=SdkDagRun.model_validate( + { + "dag_id": "dag_id_from_dagrun_not_ti", + "run_id": "dag_run_run_id_from_dagrun_not_ti", + "logical_date": timezone.datetime(2020, 1, 1, 1, 1, 1), + "start_date": timezone.datetime(2023, 1, 1, 13, 1, 1), + "end_date": timezone.datetime(2023, 1, 3, 13, 1, 1), + "run_type": DagRunType.MANUAL, + "run_after": timezone.datetime(2023, 1, 3, 13, 1, 1), + "consumed_asset_events": [], + **( + {"state": SdkDagRunState.RUNNING} if "state" in SdkDagRun.model_fields else {} + ), + } ), task_reschedule_count=0, max_tries=1, diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 3413e98189cfc..bbf6eb4dea024 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -69,6 +69,7 @@ DRCount, ErrorResponse, OKResponse, + PreviousDagRunResult, SkipDownstreamTasks, TaskRescheduleStartDate, TICount, @@ -490,6 +491,7 @@ def get_sequence_slice( start: int | None, stop: int | None, step: int | None, + include_prior_dates: bool = False, ) -> XComSequenceSliceResponse: params = {} if start is not None: @@ -498,6 +500,8 @@ def get_sequence_slice( params["stop"] = stop if step is not None: params["step"] = step + if include_prior_dates: + params["include_prior_dates"] = include_prior_dates resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice", params=params) return XComSequenceSliceResponse.model_validate_json(resp.read()) @@ -620,6 +624,23 @@ def get_count( resp = self.client.get("dag-runs/count", params=params) return DRCount(count=resp.json()) + def get_previous( + self, + dag_id: str, + logical_date: datetime, + state: str | None = None, + ) -> PreviousDagRunResult: + """Get the previous DAG run before the given logical date, optionally filtered by state.""" + params = { + "logical_date": logical_date.isoformat(), + } + + if state: + params["state"] = state + + resp = 
self.client.get(f"dag-runs/{dag_id}/previous", params=params) + return PreviousDagRunResult(dag_run=resp.json()) + class BearerAuth(httpx.Auth): def __init__(self, token: str): diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 1dabd8c90228d..9cdc08379ac75 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -27,7 +27,7 @@ from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel -API_VERSION: Final[str] = "2025-05-20" +API_VERSION: Final[str] = "2025-08-10" class AssetAliasReferenceAssetEventDagRun(BaseModel): @@ -494,6 +494,7 @@ class DagRun(BaseModel): end_date: Annotated[AwareDatetime | None, Field(title="End Date")] = None clear_number: Annotated[int | None, Field(title="Clear Number")] = 0 run_type: DagRunType + state: DagRunState conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None consumed_asset_events: Annotated[list[AssetEventDagRunReference], Field(title="Consumed Asset Events")] diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 5de513bda9e02..a3d6f58edfd08 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -1014,9 +1014,13 @@ def __init__( ): # Note: Metaclass handles passing in the DAG/TaskGroup from active context manager, if any - self.task_id = task_group.child_id(task_id) if task_group else task_id - if not self.__from_mapped and task_group: + # Only apply task_group prefix if this operator was not created from a mapped operator + # Mapped operators already have the prefix applied during their creation + if task_group and not self.__from_mapped: + self.task_id = task_group.child_id(task_id) task_group.add(self) + else: + self.task_id = task_id super().__init__() self.task_group = task_group diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 82df8d151ab13..7c982a050ddf4 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -17,6 +17,7 @@ from __future__ import annotations +import collections from typing import Any, Protocol import structlog @@ -30,6 +31,9 @@ XComSequenceSliceResult, ) +# Lightweight wrapper for XCom values +_XComValueWrapper = collections.namedtuple("_XComValueWrapper", "value") + log = structlog.get_logger(logger_name="task") @@ -273,6 +277,7 @@ def get_all( dag_id: str, task_id: str, run_id: str, + include_prior_dates: bool = False, ) -> Any: """ Retrieve all XCom values for a task, typically from all map indexes. @@ -287,10 +292,12 @@ def get_all( :param run_id: DAG run ID for the task. :param dag_id: DAG ID to pull XComs from. :param task_id: Task ID to pull XComs from. + :param include_prior_dates: If *False* (default), only XComs from the + specified DAG run are returned. If *True*, the latest matching XComs are + returned regardless of the run they belong to. :return: List of all XCom values if found. 
""" from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - from airflow.serialization.serde import deserialize msg = SUPERVISOR_COMMS.send( msg=GetXComSequenceSlice( @@ -301,16 +308,17 @@ def get_all( start=None, stop=None, step=None, + include_prior_dates=include_prior_dates, ), ) if not isinstance(msg, XComSequenceSliceResult): raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}") - result = deserialize(msg.root) - if not result: + if not msg.root: return None - return result + + return [cls.deserialize_value(_XComValueWrapper(value)) for value in msg.root] @staticmethod def serialize_value( diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index ec2fefa0a08b4..8934cd0e4532c 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -30,6 +30,7 @@ import methodtools +from airflow.configuration import conf from airflow.sdk.definitions._internal.mixins import DependencyMixin from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions._internal.templater import Templater @@ -50,7 +51,7 @@ TaskStateChangeCallback = Callable[[Context], None] -DEFAULT_OWNER: str = "airflow" +DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner") DEFAULT_POOL_SLOTS: int = 1 DEFAULT_POOL_NAME = "default_pool" DEFAULT_PRIORITY_WEIGHT: int = 1 @@ -61,17 +62,23 @@ MINIMUM_PRIORITY_WEIGHT: int = -2147483648 MAXIMUM_PRIORITY_WEIGHT: int = 2147483647 DEFAULT_EXECUTOR: str | None = None -DEFAULT_QUEUE: str = "default" +DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue") DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False -DEFAULT_RETRIES: int = 0 -DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(seconds=300) -MAX_RETRY_DELAY: int = 24 * 60 * 60 +DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0) +DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta( + seconds=conf.getint("core", "default_task_retry_delay", fallback=300) +) +MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60) # TODO: Task-SDK -- these defaults should be overridable from the Airflow config DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS -DEFAULT_WEIGHT_RULE: WeightRule = WeightRule.DOWNSTREAM -DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = None +DEFAULT_WEIGHT_RULE: WeightRule = WeightRule( + conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM) +) +DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta( + "core", "default_task_execution_timeout" +) log = logging.getLogger(__name__) diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi index 30e921f2f4881..038f94a0bd8f5 100644 --- a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi +++ b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi @@ -496,7 +496,7 @@ class TaskDecoratorCollection: """ # [END decorator_signature] @overload - def kubernetes( + def kubernetes( # type: ignore[misc] self, *, multiple_outputs: bool | None = None, @@ -670,7 +670,7 @@ class TaskDecoratorCollection: @overload def kubernetes(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: 
... @overload - def kubernetes_cmd( + def kubernetes_cmd( # type: ignore[misc] self, *, args_only: bool = False, # Added by _KubernetesCmdDecoratedOperator. diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py index cb24a7cc6bdb7..8c84ec6bbeeb0 100644 --- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -41,6 +41,7 @@ DEFAULT_WEIGHT_RULE, AbstractOperator, NotMapped, + TaskStateChangeCallback, ) from airflow.sdk.definitions._internal.expandinput import ( DictOfListsExpandInput, @@ -50,7 +51,7 @@ from airflow.sdk.definitions._internal.types import NOTSET from airflow.serialization.enums import DagAttributeTypes from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy -from airflow.typing_compat import Literal +from airflow.typing_compat import Literal, TypeGuard from airflow.utils.helpers import is_container, prevent_duplicates from airflow.utils.xcom import XCOM_RETURN_KEY @@ -60,9 +61,6 @@ import jinja2 # Slow import. import pendulum - from airflow.models.abstractoperator import ( - TaskStateChangeCallback, - ) from airflow.models.expandinput import ( OperatorExpandArgument, OperatorExpandKwargsArgument, @@ -76,14 +74,12 @@ from airflow.sdk.types import Operator from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.triggers.base import StartTriggerArgs - from airflow.typing_compat import TypeGuard from airflow.utils.context import Context from airflow.utils.operator_resources import Resources from airflow.utils.task_group import TaskGroup from airflow.utils.trigger_rule import TriggerRule - TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]] - +TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]] ValidationSource = Union[Literal["expand"], Literal["partial"]] diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 2c6dfea4e601c..3069490521783 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -70,6 +70,7 @@ AssetResponse, BundleInfo, ConnectionResponse, + DagRun, DagRunStateResponse, InactiveAssetsResponse, PrevSuccessfulDagRunResponse, @@ -492,6 +493,13 @@ def from_api_response(cls, dr_state_response: DagRunStateResponse) -> DagRunStat return cls(**dr_state_response.model_dump(exclude_defaults=True), type="DagRunStateResult") +class PreviousDagRunResult(BaseModel): + """Response containing previous DAG run information.""" + + dag_run: DagRun | None = None + type: Literal["PreviousDagRunResult"] = "PreviousDagRunResult" + + class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse): type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult" @@ -579,6 +587,7 @@ class SentFDs(BaseModel): XComSequenceSliceResult, InactiveAssetsResult, OKResponse, + PreviousDagRunResult, ], Field(discriminator="type"), ] @@ -683,6 +692,7 @@ class GetXComSequenceSlice(BaseModel): start: int | None stop: int | None step: int | None + include_prior_dates: bool = False type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice" @@ -775,6 +785,13 @@ class GetDagRunState(BaseModel): type: Literal["GetDagRunState"] = "GetDagRunState" +class GetPreviousDagRun(BaseModel): + dag_id: str + logical_date: AwareDatetime + state: str | None = None + type: 
Literal["GetPreviousDagRun"] = "GetPreviousDagRun" + + class GetAssetByName(BaseModel): name: str type: Literal["GetAssetByName"] = "GetAssetByName" @@ -853,6 +870,7 @@ class GetDRCount(BaseModel): GetDagRunState, GetDRCount, GetPrevSuccessfulDagRun, + GetPreviousDagRun, GetTaskRescheduleStartDate, GetTICount, GetTaskStates, diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 57224f66a5882..fffa1ad15553d 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -83,6 +83,7 @@ GetConnection, GetDagRunState, GetDRCount, + GetPreviousDagRun, GetPrevSuccessfulDagRun, GetTaskRescheduleStartDate, GetTaskStates, @@ -1131,7 +1132,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: resp = xcom elif isinstance(msg, GetXComSequenceSlice): xcoms = self.client.xcoms.get_sequence_slice( - msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.start, msg.stop, msg.step + msg.dag_id, + msg.run_id, + msg.task_id, + msg.key, + msg.start, + msg.stop, + msg.step, + msg.include_prior_dates, ) resp = XComSequenceSliceResult.from_response(xcoms) elif isinstance(msg, DeferTask): @@ -1227,6 +1235,12 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: run_ids=msg.run_ids, states=msg.states, ) + elif isinstance(msg, GetPreviousDagRun): + resp = self.client.dag_runs.get_previous( + dag_id=msg.dag_id, + logical_date=msg.logical_date, + state=msg.state, + ) elif isinstance(msg, DeleteVariable): resp = self.client.variables.delete(msg.key) elif isinstance(msg, ValidateInletsAndOutlets): diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 62aa7d37b7f7a..d4d09c8e20c48 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -44,6 +44,7 @@ from airflow.listeners.listener import get_listener_manager from airflow.sdk.api.datamodels._generated import ( AssetProfile, + DagRun, TaskInstance, TaskInstanceState, TIRunContext, @@ -65,10 +66,12 @@ ErrorResponse, GetDagRunState, GetDRCount, + GetPreviousDagRun, GetTaskRescheduleStartDate, GetTaskStates, GetTICount, InactiveAssetsResult, + PreviousDagRunResult, RescheduleTask, ResendLoggingFD, RetryTask, @@ -358,6 +361,7 @@ def xcom_pull( key=key, task_id=t_id, dag_id=dag_id, + include_prior_dates=include_prior_dates, ) if values is None: @@ -438,6 +442,30 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: return response.start_date + def get_previous_dagrun(self, state: str | None = None) -> DagRun | None: + """Return the previous DAG run before the given logical date, optionally filtered by state.""" + context = self.get_template_context() + dag_run = context.get("dag_run") + + log = structlog.get_logger(logger_name="task") + + log.debug("Getting previous DAG run", dag_run=dag_run) + + if dag_run is None: + return None + + if dag_run.logical_date is None: + return None + + response = SUPERVISOR_COMMS.send( + msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state) + ) + + if TYPE_CHECKING: + assert isinstance(response, PreviousDagRunResult) + + return response.dag_run + @staticmethod def get_ti_count( dag_id: str, diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 8bd0ea0db8d4d..abe1f1f84a822 100644 --- 
a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -85,6 +85,8 @@ def get_template_context(self) -> Context: ... def get_first_reschedule_date(self, first_try_number) -> AwareDatetime | None: ... + def get_previous_dagrun(self, state: str | None = None) -> DagRunProtocol | None: ... + @staticmethod def get_ti_count( dag_id: str, diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py index d1866f523919c..c3195a7b9bb8b 100644 --- a/task-sdk/tests/conftest.py +++ b/task-sdk/tests/conftest.py @@ -198,6 +198,7 @@ def __call__( def make_ti_context() -> MakeTIContextCallable: """Factory for creating TIRunContext objects.""" from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext + from airflow.utils.state import DagRunState def _make_context( dag_id: str = "test_dag", @@ -226,6 +227,7 @@ def _make_context( start_date=start_date, # type: ignore run_type=run_type, # type: ignore run_after=run_after, # type: ignore + state=DagRunState.RUNNING, conf=conf, # type: ignore consumed_asset_events=list(consumed_asset_events), ), diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index caa515de09a4d..4ffa847983546 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -19,6 +19,7 @@ import json import pickle +from datetime import datetime from unittest import mock import httpx @@ -41,6 +42,7 @@ DeferTask, ErrorResponse, OKResponse, + PreviousDagRunResult, RescheduleTask, TaskRescheduleStartDate, ) @@ -1139,6 +1141,86 @@ def handle_request(request: httpx.Request) -> httpx.Response: result = client.dag_runs.get_count(dag_id="test_dag", run_ids=["run1", "run2"]) assert result.count == 2 + def test_get_previous_basic(self): + """Test basic get_previous functionality with dag_id and logical_date.""" + logical_date = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dag-runs/test_dag/previous": + assert request.url.params["logical_date"] == logical_date.isoformat() + # Return complete DagRun data + return httpx.Response( + status_code=200, + json={ + "dag_id": "test_dag", + "run_id": "prev_run", + "logical_date": "2024-01-14T12:00:00+00:00", + "start_date": "2024-01-14T12:05:00+00:00", + "run_after": "2024-01-14T12:00:00+00:00", + "run_type": "scheduled", + "state": "success", + "consumed_asset_events": [], + }, + ) + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dag_runs.get_previous(dag_id="test_dag", logical_date=logical_date) + + assert isinstance(result, PreviousDagRunResult) + assert result.dag_run.dag_id == "test_dag" + assert result.dag_run.run_id == "prev_run" + assert result.dag_run.state == "success" + + def test_get_previous_with_state_filter(self): + """Test get_previous functionality with state filtering.""" + logical_date = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dag-runs/test_dag/previous": + assert request.url.params["logical_date"] == logical_date.isoformat() + assert request.url.params["state"] == "success" + # Return complete DagRun data + return httpx.Response( + status_code=200, + json={ + "dag_id": "test_dag", + "run_id": "prev_success_run", + "logical_date": "2024-01-14T12:00:00+00:00", + "start_date": "2024-01-14T12:05:00+00:00", + "run_after": 
"2024-01-14T12:00:00+00:00", + "run_type": "scheduled", + "state": "success", + "consumed_asset_events": [], + }, + ) + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dag_runs.get_previous(dag_id="test_dag", logical_date=logical_date, state="success") + + assert isinstance(result, PreviousDagRunResult) + assert result.dag_run.dag_id == "test_dag" + assert result.dag_run.run_id == "prev_success_run" + assert result.dag_run.state == "success" + + def test_get_previous_not_found(self): + """Test get_previous when no previous DAG run exists returns None.""" + logical_date = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dag-runs/test_dag/previous": + assert request.url.params["logical_date"] == logical_date.isoformat() + # Return None (null) when no previous DAG run found + return httpx.Response(status_code=200, content="null") + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dag_runs.get_previous(dag_id="test_dag", logical_date=logical_date) + + assert isinstance(result, PreviousDagRunResult) + assert result.dag_run is None + class TestTaskRescheduleOperations: def test_get_start_date(self): diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py index f9c075352a0d4..7a4b43c81040a 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py +++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py @@ -691,3 +691,25 @@ def group(x): ), ] ) + + +def test_mapped_operator_in_task_group_no_duplicate_prefix(): + """Test that task_id doesn't get duplicated prefix when unmapping a mapped operator in a task group.""" + from airflow.sdk.definitions.taskgroup import TaskGroup + + with DAG("test-dag"): + with TaskGroup(group_id="tg1") as tg1: + # Create a mapped task within the task group + mapped_task = MockOperator.partial(task_id="mapped_task", arg1="a").expand(arg2=["a", "b", "c"]) + + # Check the mapped operator has correct task_id + assert mapped_task.task_id == "tg1.mapped_task" + assert mapped_task.task_group == tg1 + assert mapped_task.task_group.group_id == "tg1" + + # Simulate what happens during execution - unmap the operator + # unmap expects resolved kwargs + unmapped = mapped_task.unmap({"arg2": "a"}) + + # The unmapped operator should have the same task_id, not a duplicate prefix + assert unmapped.task_id == "tg1.mapped_task", f"Expected 'tg1.mapped_task' but got '{unmapped.task_id}'" diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index 48c7ad74a1501..d0be736b4cea9 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -56,6 +56,7 @@ def test_recv_StartupDetails(self): "run_after": "2024-12-01T01:00:00Z", "end_date": None, "run_type": "manual", + "state": "success", "conf": None, "consumed_asset_events": [], }, diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 600ca4ae54972..196494c258ddf 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -51,7 +51,9 @@ AssetEventResponse, AssetProfile, AssetResponse, + DagRun, 
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 600ca4ae54972..196494c258ddf 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -51,7 +51,9 @@
     AssetEventResponse,
     AssetProfile,
     AssetResponse,
+    DagRun,
     DagRunState,
+    DagRunType,
     TaskInstance,
     TaskInstanceState,
 )
@@ -75,6 +77,7 @@
     GetConnection,
     GetDagRunState,
     GetDRCount,
+    GetPreviousDagRun,
     GetPrevSuccessfulDagRun,
     GetTaskRescheduleStartDate,
     GetTaskStates,
@@ -85,6 +88,7 @@
     GetXComSequenceSlice,
     InactiveAssetsResult,
     OKResponse,
+    PreviousDagRunResult,
     PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
@@ -1822,6 +1826,72 @@ def watched_subprocess(self, mocker):
             None,
             id="get_dr_count",
         ),
+        pytest.param(
+            GetPreviousDagRun(
+                dag_id="test_dag",
+                logical_date=timezone.parse("2024-01-15T12:00:00Z"),
+            ),
+            {
+                "dag_run": {
+                    "dag_id": "test_dag",
+                    "run_id": "prev_run",
+                    "logical_date": timezone.parse("2024-01-14T12:00:00Z"),
+                    "run_type": "scheduled",
+                    "start_date": timezone.parse("2024-01-15T12:00:00Z"),
+                    "run_after": timezone.parse("2024-01-15T12:00:00Z"),
+                    "consumed_asset_events": [],
+                    "state": "success",
+                    "data_interval_start": None,
+                    "data_interval_end": None,
+                    "end_date": None,
+                    "clear_number": 0,
+                    "conf": None,
+                },
+                "type": "PreviousDagRunResult",
+            },
+            "dag_runs.get_previous",
+            (),
+            {
+                "dag_id": "test_dag",
+                "logical_date": timezone.parse("2024-01-15T12:00:00Z"),
+                "state": None,
+            },
+            PreviousDagRunResult(
+                dag_run=DagRun(
+                    dag_id="test_dag",
+                    run_id="prev_run",
+                    logical_date=timezone.parse("2024-01-14T12:00:00Z"),
+                    run_type=DagRunType.SCHEDULED,
+                    start_date=timezone.parse("2024-01-15T12:00:00Z"),
+                    run_after=timezone.parse("2024-01-15T12:00:00Z"),
+                    consumed_asset_events=[],
+                    state=DagRunState.SUCCESS,
+                )
+            ),
+            None,
+            id="get_previous_dagrun",
+        ),
+        pytest.param(
+            GetPreviousDagRun(
+                dag_id="test_dag",
+                logical_date=timezone.parse("2024-01-15T12:00:00Z"),
+                state="success",
+            ),
+            {
+                "dag_run": None,
+                "type": "PreviousDagRunResult",
+            },
+            "dag_runs.get_previous",
+            (),
+            {
+                "dag_id": "test_dag",
+                "logical_date": timezone.parse("2024-01-15T12:00:00Z"),
+                "state": "success",
+            },
+            PreviousDagRunResult(dag_run=None),
+            None,
+            id="get_previous_dagrun_with_state",
+        ),
         pytest.param(
             GetTaskStates(dag_id="test_dag", task_group_id="test_group"),
             {
@@ -1884,10 +1954,11 @@ def watched_subprocess(self, mocker):
                 start=None,
                 stop=None,
                 step=None,
+                include_prior_dates=False,
             ),
             {"root": ["foo", "bar"], "type": "XComSequenceSliceResult"},
             "xcoms.get_sequence_slice",
-            ("test_dag", "test_run", "test_task", "test_key", None, None, None),
+            ("test_dag", "test_run", "test_task", "test_key", None, None, None, False),
             {},
             XComSequenceSliceResult(root=["foo", "bar"]),
             None,
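A short sketch of what the ``include_prior_dates`` plumbing enables from task code once this patch is applied; the operator class, task id, and key below are illustrative:

.. code-block:: python

    from airflow.sdk import BaseOperator


    class RollingTotal(BaseOperator):
        def execute(self, context):
            # The flag now travels through GetXComSequenceSlice to the
            # execution API, so XComs pushed by runs with earlier logical
            # dates are also searched.
            prev = context["ti"].xcom_pull(
                task_ids="producer", key="rolling_total", include_prior_dates=True
            )
            return (prev or 0) + 1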
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 362b89e3ad0a5..2556845dac897 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -50,6 +50,7 @@
 from airflow.sdk.api.datamodels._generated import (
     AssetProfile,
     AssetResponse,
+    DagRun,
     DagRunState,
     TaskInstance,
     TaskInstanceState,
@@ -71,12 +72,14 @@
     GetConnection,
     GetDagRunState,
     GetDRCount,
+    GetPreviousDagRun,
     GetTaskStates,
     GetTICount,
     GetVariable,
     GetXCom,
     GetXComSequenceSlice,
     OKResponse,
+    PreviousDagRunResult,
     PrevSuccessfulDagRunResult,
     SetRenderedFields,
     SetXCom,
@@ -1802,6 +1805,66 @@ def test_get_task_states(self, mock_supervisor_comms):
         )
         assert states == {"run1": {"task1": "running"}}
 
+    def test_get_previous_dagrun_basic(self, create_runtime_ti, mock_supervisor_comms):
+        """Test that get_previous_dagrun sends the correct request without a state filter."""
+
+        task = BaseOperator(task_id="hello")
+        dag_id = "test_dag"
+        runtime_ti = create_runtime_ti(task=task, dag_id=dag_id, logical_date=timezone.datetime(2025, 1, 2))
+
+        dag_run_data = DagRun(
+            dag_id=dag_id,
+            run_id="prev_run",
+            logical_date=timezone.datetime(2025, 1, 1),
+            start_date=timezone.datetime(2025, 1, 1),
+            run_after=timezone.datetime(2025, 1, 1),
+            run_type="scheduled",
+            state="success",
+            consumed_asset_events=[],
+        )
+
+        mock_supervisor_comms.send.return_value = PreviousDagRunResult(dag_run=dag_run_data)
+
+        dr = runtime_ti.get_previous_dagrun()
+
+        mock_supervisor_comms.send.assert_called_once_with(
+            msg=GetPreviousDagRun(dag_id="test_dag", logical_date=timezone.datetime(2025, 1, 2), state=None),
+        )
+        assert dr.dag_id == "test_dag"
+        assert dr.run_id == "prev_run"
+        assert dr.state == "success"
+
+    def test_get_previous_dagrun_with_state(self, create_runtime_ti, mock_supervisor_comms):
+        """Test that get_previous_dagrun sends the correct request with a state filter."""
+
+        task = BaseOperator(task_id="hello")
+        dag_id = "test_dag"
+        runtime_ti = create_runtime_ti(task=task, dag_id=dag_id, logical_date=timezone.datetime(2025, 1, 2))
+
+        dag_run_data = DagRun(
+            dag_id=dag_id,
+            run_id="prev_success_run",
+            logical_date=timezone.datetime(2025, 1, 1),
+            start_date=timezone.datetime(2025, 1, 1),
+            run_after=timezone.datetime(2025, 1, 1),
+            run_type="scheduled",
+            state="success",
+            consumed_asset_events=[],
+        )
+
+        mock_supervisor_comms.send.return_value = PreviousDagRunResult(dag_run=dag_run_data)
+
+        dr = runtime_ti.get_previous_dagrun(state="success")
+
+        mock_supervisor_comms.send.assert_called_once_with(
+            msg=GetPreviousDagRun(
+                dag_id="test_dag", logical_date=timezone.datetime(2025, 1, 2), state="success"
+            ),
+        )
+        assert dr.dag_id == "test_dag"
+        assert dr.run_id == "prev_success_run"
+        assert dr.state == "success"
+
 
 class TestXComAfterTaskExecution:
     @pytest.mark.parametrize(
@@ -1986,8 +2049,7 @@ def test_xcom_pull_from_custom_xcom_backend(
 
         class CustomOperator(BaseOperator):
             def execute(self, context):
-                value = context["ti"].xcom_pull(task_ids="pull_task", key="key")
-                print(f"Pulled XCom Value: {value}")
+                context["ti"].xcom_pull(task_ids="pull_task", key="key")
 
         task = CustomOperator(task_id="pull_task")
         runtime_ti = create_runtime_ti(task=task)
@@ -1998,6 +2060,7 @@ def execute(self, context):
             dag_id="test_dag",
             task_id="pull_task",
             run_id="test_run",
+            include_prior_dates=False,
         )
 
         assert not any(
@@ -2014,6 +2077,81 @@ def execute(self, context):
             for x in mock_supervisor_comms.send.call_args_list
         )
 
+    def test_get_all_uses_custom_deserialize_value(self, mock_supervisor_comms):
+        """
+        Tests that XCom.get_all() calls the custom deserialize_value method.
+ """ + + class CustomXCom(BaseXCom): + @classmethod + def deserialize_value(cls, result): + """Custom deserialization that adds a prefix to show it was called.""" + original_value = super().deserialize_value(result) + return f"from custom xcom deserialize:{original_value}" + + serialized_values = ["value1", "value2", "value3"] + mock_supervisor_comms.send.return_value = XComSequenceSliceResult(root=serialized_values) + + result = CustomXCom.get_all(key="test_key", dag_id="test_dag", task_id="test_task", run_id="test_run") + + expected = [ + "from custom xcom deserialize:value1", + "from custom xcom deserialize:value2", + "from custom xcom deserialize:value3", + ] + assert result == expected + + @pytest.mark.parametrize( + ("include_prior_dates", "expected_value"), + [ + pytest.param(True, True, id="include_prior_dates_true"), + pytest.param(False, False, id="include_prior_dates_false"), + pytest.param(None, False, id="include_prior_dates_default"), + ], + ) + def test_xcom_pull_with_include_prior_dates( + self, + create_runtime_ti, + mock_supervisor_comms, + include_prior_dates, + expected_value, + ): + """Test that xcom_pull with include_prior_dates parameter correctly behaves as we expect.""" + task = BaseOperator(task_id="pull_task") + runtime_ti = create_runtime_ti(task=task) + + value = {"previous_run_data": "test_value"} + ser_value = BaseXCom.serialize_value(value) + + def mock_send_side_effect(*args, **kwargs): + msg = kwargs.get("msg") or args[0] + if isinstance(msg, GetXComSequenceSlice): + assert msg.include_prior_dates is expected_value, ( + f"include_prior_dates should be {expected_value} in GetXComSequenceSlice" + ) + return XComSequenceSliceResult(root=[ser_value]) + return XComResult(key="test_key", value=None) + + mock_supervisor_comms.send.side_effect = mock_send_side_effect + kwargs = {"key": "test_key", "task_ids": "previous_task"} + if include_prior_dates is not None: + kwargs["include_prior_dates"] = include_prior_dates + result = runtime_ti.xcom_pull(**kwargs) + assert result == value + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXComSequenceSlice( + key="test_key", + dag_id=runtime_ti.dag_id, + run_id=runtime_ti.run_id, + task_id="previous_task", + start=None, + stop=None, + step=None, + include_prior_dates=expected_value, + ), + ) + class TestDagParamRuntime: DEFAULT_ARGS = {