From d027d4d7913556b0195d2984d1626eec4cb1e6d8 Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Thu, 26 Jun 2025 18:53:20 +0530
Subject: [PATCH 1/3] Replace `models.BaseOperator` with the Task SDK one for
 Standard Provider

Providers should use the `BaseOperator` from the Task SDK on Airflow 3.0+.

---
 airflow-core/tests/unit/core/test_core.py     |  24 ++--
 .../tests/unit/models/test_cleartasks.py      |   9 +-
 .../tests/unit/models/test_taskinstance.py    |  68 +++++----
 .../src/tests_common/pytest_plugin.py         |  35 +++++
 .../tests/system/amazon/aws/utils/k8s.py      |   7 +-
 .../apache/spark/decorators/test_pyspark.py   |  16 +--
 .../operators/test_campaign_manager.py        |  47 +++---
 .../operators/test_display_video.py           |  43 +++---
 .../example_openlineage_base_complex_dag.py   |   5 +-
 .../unit/openlineage/extractors/test_base.py  |   8 +-
 .../unit/openlineage/plugins/test_listener.py |   8 +-
 .../unit/openlineage/utils/test_utils.py      |   5 +-
 .../snowflake/decorators/test_snowpark.py     |  22 ++-
 .../unit/snowflake/operators/test_snowpark.py |  18 ++-
 .../providers/standard/operators/bash.py      |   3 +-
 .../providers/standard/operators/branch.py    |   3 +-
 .../providers/standard/operators/empty.py     |   7 +-
 .../providers/standard/operators/python.py    |   3 +-
 .../providers/standard/operators/smooth.py    |   7 +-
 .../standard/sensors/external_task.py         |   5 +-
 .../providers/standard/utils/skipmixin.py     |  15 +-
 .../decorators/test_branch_external_python.py |  14 +-
 .../standard/decorators/test_branch_python.py |  14 +-
 .../decorators/test_branch_virtualenv.py      |  14 +-
 .../decorators/test_external_python.py        |  72 +++++-----
 .../unit/standard/decorators/test_python.py   | 113 +++++----------
 .../decorators/test_python_virtualenv.py      | 100 ++++++-------
 .../standard/decorators/test_short_circuit.py |  12 +-
 .../unit/standard/operators/test_bash.py      |  12 +-
 .../operators/test_branch_operator.py         |  41 +++---
 .../unit/standard/operators/test_datetime.py  |  29 ++--
 .../operators/test_latest_only_operator.py    |  25 ++--
 .../unit/standard/operators/test_python.py    | 134 +++++++++---------
 .../unit/standard/operators/test_weekday.py   |  14 +-
 .../sensors/test_external_task_sensor.py      |   7 +-
 task-sdk/src/airflow/sdk/bases/decorator.py   |   5 +-
 36 files changed, 495 insertions(+), 469 deletions(-)

diff --git a/airflow-core/tests/unit/core/test_core.py b/airflow-core/tests/unit/core/test_core.py
index d0022ed33d8b6..cfdb9dee32357 100644
--- a/airflow-core/tests/unit/core/test_core.py
+++ b/airflow-core/tests/unit/core/test_core.py
@@ -24,9 +24,7 @@
 import pytest
 
 from airflow.exceptions import AirflowTaskTimeout
-from airflow.models import TaskInstance
 from airflow.models.baseoperator import BaseOperator
-from airflow.providers.standard.operators.bash import BashOperator
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.utils.timezone import datetime
@@ -49,8 +47,15 @@ def teardown_method(self):
         self.clean_db()
 
     def test_dryrun(self, dag_maker):
+        class TemplateFieldOperator(BaseOperator):
+            template_fields = ["bash_command"]
+
+            def __init__(self, bash_command, **kwargs):
+                self.bash_command = bash_command
+                super().__init__(**kwargs)
+
         with dag_maker():
-            op = BashOperator(task_id="test_dryrun", bash_command="echo success")
+            op = TemplateFieldOperator(task_id="test_dryrun", bash_command="echo success")
         dag_maker.create_dagrun()
 
         op.dry_run()
@@ -81,9 +86,8 @@ def sleep_and_catch_other_exceptions():
             execution_timeout=timedelta(seconds=1),
             python_callable=sleep_and_catch_other_exceptions,
        )
-
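# A minimal sketch (not part of the patch) of the compatibility pattern the
# commit message describes and the provider diffs below apply repeatedly:
# import BaseOperator from the Task SDK on Airflow 3.0+ and fall back to the
# legacy location on Airflow 2.x. AIRFLOW_V_3_0_PLUS and both import paths are
# taken from the patch itself; the operator class is illustrative only.
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
    from airflow.sdk import BaseOperator
else:
    from airflow.models.baseoperator import BaseOperator


class ExampleOperator(BaseOperator):
    """Hypothetical operator that runs on both Airflow 2.x and 3.x."""

    def execute(self, context):
        self.log.info("running on a version-appropriate BaseOperator")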
dag_maker.create_dagrun() with pytest.raises(AirflowTaskTimeout): - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + dag_maker.run_ti(op.task_id) def test_dag_params_and_task_params(self, dag_maker): # This test case guards how params of DAG and Operator work together. @@ -91,7 +95,6 @@ def test_dag_params_and_task_params(self, dag_maker): # it is guaranteed to be available eventually. # - If any key exists in both DAG's params and Operator's params, # the latter has precedence. - TI = TaskInstance with dag_maker( schedule=timedelta(weeks=1), @@ -106,12 +109,9 @@ def test_dag_params_and_task_params(self, dag_maker): dr = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, ) - task1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - task2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - ti1 = TI(task=task1, run_id=dr.run_id) - ti2 = TI(task=task2, run_id=dr.run_id) - ti1.refresh_from_db() - ti2.refresh_from_db() + ti1 = dag_maker.run_ti(task1.task_id, dr) + ti2 = dag_maker.run_ti(task2.task_id, dr) + context1 = ti1.get_template_context() context2 = ti2.get_template_context() diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py index 9ece88246196a..002165ff71ec2 100644 --- a/airflow-core/tests/unit/models/test_cleartasks.py +++ b/airflow-core/tests/unit/models/test_cleartasks.py @@ -23,6 +23,7 @@ import pytest from sqlalchemy import select +from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance, TaskInstance as TI, clear_task_instances @@ -677,9 +678,13 @@ def _get_ti(old_ti): assert ti.max_tries == 1 def test_operator_clear(self, dag_maker, session): + class ClearOperator(BaseOperator): + def execute(self, context): + pass + with dag_maker("test_operator_clear"): - op1 = EmptyOperator(task_id="test1") - op2 = EmptyOperator(task_id="test2", retries=1) + op1 = ClearOperator(task_id="test1") + op2 = ClearOperator(task_id="test2", retries=1) op1 >> op2 dr = dag_maker.create_dagrun( diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 5c6e4f69a5256..8f0de3d07c95e 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -41,6 +41,7 @@ AirflowSkipException, ) from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel +from airflow.models.baseoperator import BaseOperator from airflow.models.connection import Connection from airflow.models.dag import DAG from airflow.models.dagrun import DagRun @@ -775,8 +776,11 @@ def run_ti_and_assert( assert not task_reschedules_for_ti(ti) def test_depends_on_past_catchup_true(self, dag_maker): + class CustomOp(BaseOperator): + def execute(self, context): ... + with dag_maker(dag_id="test_depends_on_past", serialized=True, catchup=True): - task = EmptyOperator( + task = CustomOp( task_id="test_dop_task", depends_on_past=True, ) @@ -806,8 +810,11 @@ def test_depends_on_past_catchup_true(self, dag_maker): assert ti.state == State.SUCCESS def test_depends_on_past_catchup_false(self, dag_maker): + class CustomOp(BaseOperator): + def execute(self, context): ... 
+ with dag_maker(dag_id="test_depends_on_past_catchup_false", serialized=True, catchup=False): - task = EmptyOperator( + task = CustomOp( task_id="test_dop_task", depends_on_past=True, ) @@ -2281,7 +2288,10 @@ def test_template_with_json_variable_missing(self, create_task_instance, session ti.task.render_template('{{ var.json.get("missing_variable") }}', context) @provide_session - def test_handle_failure(self, create_dummy_dag, session=None): + def test_handle_failure(self, dag_maker, session=None): + class CustomOp(BaseOperator): + def execute(self, context): ... + start_date = timezone.datetime(2016, 6, 1) clear_db_runs() @@ -2298,16 +2308,15 @@ def test_handle_failure(self, create_dummy_dag, session=None): __name__="mock_on_retry_1", __call__=mock.MagicMock(), ) - dag, task1 = create_dummy_dag( - dag_id="test_handle_failure", - schedule=None, - start_date=start_date, - task_id="test_handle_failure_on_failure", - with_dagrun_type=DagRunType.MANUAL, - on_failure_callback=mock_on_failure_1, - on_retry_callback=mock_on_retry_1, - session=session, - ) + with dag_maker(dag_id="test_handle_failure", start_date=start_date, schedule=None) as dag: + task1 = CustomOp( + task_id="test_handle_failure_on_failure", + on_failure_callback=mock_on_failure_1, + on_retry_callback=mock_on_retry_1, + ) + + dag_maker.create_dagrun(session=session, run_type=DagRunType.MANUAL, start_date=start_date) + logical_date = timezone.utcnow() dr = dag.create_dagrun( run_id="test2", @@ -2347,7 +2356,7 @@ def test_handle_failure(self, create_dummy_dag, session=None): __name__="mock_on_retry_2", __call__=mock.MagicMock(), ) - task2 = EmptyOperator( + task2 = CustomOp( task_id="test_handle_failure_on_retry", on_failure_callback=mock_on_failure_2, on_retry_callback=mock_on_retry_2, @@ -2375,7 +2384,7 @@ def test_handle_failure(self, create_dummy_dag, session=None): __name__="mock_on_retry_3", __call__=mock.MagicMock(), ) - task3 = EmptyOperator( + task3 = CustomOp( task_id="test_handle_failure_on_force_fail", on_failure_callback=mock_on_failure_3, on_retry_callback=mock_on_retry_3, @@ -2450,20 +2459,23 @@ def test_handle_failure_task_undefined(self, create_task_instance): del ti.task ti.handle_failure("test ti.task undefined") - def test_handle_failure_fail_fast(self, create_dummy_dag, session): + def test_handle_failure_fail_fast(self, dag_maker, session): start_date = timezone.datetime(2016, 6, 1) clear_db_runs() - dag, task1 = create_dummy_dag( + class CustomOp(BaseOperator): + def execute(self, context): ... 
+ + with dag_maker( dag_id="test_handle_failure_fail_fast", - schedule=None, start_date=start_date, - task_id="task1", - trigger_rule="all_success", - with_dagrun_type=DagRunType.MANUAL, - session=session, + schedule=None, fail_fast=True, - ) + ) as dag: + task1 = CustomOp(task_id="task1", trigger_rule="all_success") + + dag_maker.create_dagrun(run_type=DagRunType.MANUAL, start_date=start_date) + logical_date = timezone.utcnow() dr = dag.create_dagrun( run_id="test_ff", @@ -2484,19 +2496,13 @@ def test_handle_failure_fail_fast(self, create_dummy_dag, session): states = [State.RUNNING, State.FAILED, State.QUEUED, State.SCHEDULED, State.DEFERRED] tasks = [] for i, state in enumerate(states): - op = EmptyOperator( - task_id=f"reg_Task{i}", - dag=dag, - ) + op = CustomOp(task_id=f"reg_Task{i}", dag=dag) ti = TI(task=op, run_id=dr.run_id) ti.state = state session.add(ti) tasks.append(ti) - fail_task = EmptyOperator( - task_id="fail_Task", - dag=dag, - ) + fail_task = CustomOp(task_id="fail_Task", dag=dag) ti_ff = TI(task=fail_task, run_id=dr.run_id) ti_ff.state = State.FAILED session.add(ti_ff) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 242723bad583e..4a6f86f7ac66c 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -806,6 +806,14 @@ def create_dagrun( def create_dagrun_after(self, dagrun: DagRun, **kwargs) -> DagRun: ... + def run_ti( + self, + task_id: str, + dag_run: DagRun | None = ..., + dag_run_kwargs: dict | None = ..., + **kwargs, + ) -> TaskInstance: ... + def __call__( self, dag_id: str = "test_dag", @@ -1100,6 +1108,33 @@ def create_dagrun_after(self, dagrun, **kwargs): **kwargs, ) + def run_ti(self, task_id, dag_run=None, dag_run_kwargs=None, **kwargs): + """ + Create a dagrun and run a specific task instance with proper task refresh. + + This is a convenience method for running a single task instance: + 1. Create a dagrun if it does not exist + 2. Get the specific task instance by task_id + 3. Refresh the task instance from the DAG task + 4. Run the task instance + + Returns the created TaskInstance. + """ + if dag_run is None: + if dag_run_kwargs is None: + dag_run_kwargs = {} + dag_run = self.create_dagrun(**dag_run_kwargs) + ti = dag_run.get_task_instance(task_id=task_id) + if ti is None: + available_task_ids = [task.task_id for task in self.dag.tasks] + raise ValueError( + f"Task instance with task_id '{task_id}' not found in dag run. " + f"Available task_ids: {available_task_ids}" + ) + ti.refresh_from_task(self.dag.get_task(ti.task_id)) + ti.run(**kwargs) + return ti + def sync_dagbag_to_db(self): if not AIRFLOW_V_3_0_PLUS: self.dagbag.sync_to_db() diff --git a/providers/amazon/tests/system/amazon/aws/utils/k8s.py b/providers/amazon/tests/system/amazon/aws/utils/k8s.py index a882d9e42842e..b2c82ae82c674 100644 --- a/providers/amazon/tests/system/amazon/aws/utils/k8s.py +++ b/providers/amazon/tests/system/amazon/aws/utils/k8s.py @@ -16,15 +16,10 @@ # under the License. 
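# A usage sketch (not part of the patch) for the `run_ti` helper added to the
# dag_maker fixture in pytest_plugin.py above. The test body and task names are
# illustrative; run_ti's behaviour — create a dagrun when none is passed,
# refresh the TI from the DAG's task, run it, and return the TaskInstance — is
# taken from the implementation in the diff.
def test_run_ti_sketch(dag_maker):
    from airflow.providers.standard.operators.python import PythonOperator

    with dag_maker(dag_id="run_ti_demo", serialized=True):
        PythonOperator(task_id="answer", python_callable=lambda: 42)

    ti = dag_maker.run_ti("answer")  # implicit create_dagrun() + ti.run()
    assert ti.state == "success"  # i.e. TaskInstanceState.SUCCESS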
from __future__ import annotations -from typing import TYPE_CHECKING - from airflow.providers.standard.operators.bash import BashOperator -if TYPE_CHECKING: - from airflow.models.operator import Operator - -def get_describe_pod_operator(cluster_name: str, pod_name: str) -> Operator: +def get_describe_pod_operator(cluster_name: str, pod_name: str) -> BashOperator: """Returns an operator that'll print the output of a `k describe pod` in the airflow logs.""" return BashOperator( task_id="describe_pod", diff --git a/providers/apache/spark/tests/unit/apache/spark/decorators/test_pyspark.py b/providers/apache/spark/tests/unit/apache/spark/decorators/test_pyspark.py index 4cb7816c92cde..cb37e528d43ca 100644 --- a/providers/apache/spark/tests/unit/apache/spark/decorators/test_pyspark.py +++ b/providers/apache/spark/tests/unit/apache/spark/decorators/test_pyspark.py @@ -103,11 +103,11 @@ def f(spark, sc): return [random.random() for _ in range(100)] with dag_maker(): - ret = f() + f() dr = dag_maker.create_dagrun() - ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date) ti = dr.get_task_instances()[0] + ti.run() assert len(ti.xcom_pull()) == 100 assert config.get("spark.master") == "spark://none" assert config.get("spark.executor.memory") == "2g" @@ -130,11 +130,11 @@ def f(): return e with dag_maker(): - ret = f() + f() dr = dag_maker.create_dagrun() - ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date) ti = dr.get_task_instances()[0] + ti.run() assert ti.xcom_pull() == e assert config.get("spark.master") == "local[*]" spark_mock.builder.config.assert_called_once_with(conf=conf_mock()) @@ -154,11 +154,11 @@ def f(spark, sc): return True with dag_maker(): - ret = f() + f() dr = dag_maker.create_dagrun() - ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date) ti = dr.get_task_instances()[0] + ti.run() assert ti.xcom_pull() assert config.get("spark.remote") == "sc://localhost" assert config.get("spark.master") is None @@ -180,11 +180,11 @@ def f(spark, sc): return True with dag_maker(): - ret = f() + f() dr = dag_maker.create_dagrun() - ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date) ti = dr.get_task_instances()[0] + ti.run() assert ti.xcom_pull() assert config.get("spark.remote") == "sc://localhost/;user_id=connect;token=1234;use_ssl=True" assert config.get("spark.master") is None diff --git a/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py b/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py index f5c03c85eb5dd..cb2af61d9d297 100644 --- a/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py +++ b/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py @@ -23,7 +23,7 @@ import pytest -from airflow.models import DAG, TaskInstance as TI +from airflow.models import TaskInstance as TI from airflow.providers.google.marketing_platform.operators.campaign_manager import ( GoogleCampaignManagerBatchInsertConversionsOperator, GoogleCampaignManagerBatchUpdateConversionsOperator, @@ -153,7 +153,7 @@ def test_execute( @pytest.mark.parametrize( "test_bucket_name", - [BUCKET_NAME, f"gs://{BUCKET_NAME}", "XComArg", "{{ ti.xcom_pull(task_ids='f') }}"], + [BUCKET_NAME, f"gs://{BUCKET_NAME}", "XComArg", "{{ ti.xcom_pull(task_ids='taskflow_op') }}"], ) @mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.http") 
@mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.tempfile") @@ -168,6 +168,7 @@ def test_set_bucket_name( tempfile_mock, http_mock, test_bucket_name, + dag_maker, ): http_mock.MediaIoBaseDownload.return_value.next_chunk.return_value = ( None, @@ -175,33 +176,29 @@ def test_set_bucket_name( ) tempfile_mock.NamedTemporaryFile.return_value.__enter__.return_value.name = TEMP_FILE_NAME - dag = DAG( - dag_id="test_set_bucket_name", - start_date=DEFAULT_DATE, - schedule=None, - catchup=False, - ) + with dag_maker(dag_id="test_set_bucket_name", start_date=DEFAULT_DATE) as dag: + if BUCKET_NAME not in test_bucket_name: - if BUCKET_NAME not in test_bucket_name: + @dag.task(task_id="taskflow_op") + def f(): + return BUCKET_NAME - @dag.task - def f(): - return BUCKET_NAME + taskflow_op = f() - taskflow_op = f() - taskflow_op.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + GoogleCampaignManagerDownloadReportOperator( + profile_id=PROFILE_ID, + report_id=REPORT_ID, + file_id=FILE_ID, + bucket_name=test_bucket_name if test_bucket_name != "XComArg" else taskflow_op, + report_name=REPORT_NAME, + api_version=API_VERSION, + task_id="test_task", + ) - op = GoogleCampaignManagerDownloadReportOperator( - profile_id=PROFILE_ID, - report_id=REPORT_ID, - file_id=FILE_ID, - bucket_name=test_bucket_name if test_bucket_name != "XComArg" else taskflow_op, - report_name=REPORT_NAME, - api_version=API_VERSION, - task_id="test_task", - dag=dag, - ) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dr = dag_maker.create_dagrun() + + for ti in dr.get_task_instances(): + ti.run() gcs_hook_mock.return_value.upload.assert_called_once_with( bucket_name=BUCKET_NAME, diff --git a/providers/google/tests/unit/google/marketing_platform/operators/test_display_video.py b/providers/google/tests/unit/google/marketing_platform/operators/test_display_video.py index 79f15a2cdae8f..8596743dd7d6b 100644 --- a/providers/google/tests/unit/google/marketing_platform/operators/test_display_video.py +++ b/providers/google/tests/unit/google/marketing_platform/operators/test_display_video.py @@ -24,7 +24,7 @@ import pytest from airflow.exceptions import AirflowException -from airflow.models import DAG, TaskInstance as TI +from airflow.models import TaskInstance as TI from airflow.providers.google.marketing_platform.operators.display_video import ( GoogleDisplayVideo360CreateQueryOperator, GoogleDisplayVideo360CreateSDFDownloadTaskOperator, @@ -143,7 +143,7 @@ def test_execute( @pytest.mark.parametrize( "test_bucket_name", - [BUCKET_NAME, f"gs://{BUCKET_NAME}", "XComArg", "{{ ti.xcom_pull(task_ids='f') }}"], + [BUCKET_NAME, f"gs://{BUCKET_NAME}", "XComArg", "{{ ti.xcom_pull(task_ids='taskflow_op') }}"], ) @mock.patch("airflow.providers.google.marketing_platform.operators.display_video.shutil") @mock.patch("airflow.providers.google.marketing_platform.operators.display_video.urllib.request") @@ -160,37 +160,34 @@ def test_set_bucket_name( mock_request, mock_shutil, test_bucket_name, + dag_maker, ): mock_temp.NamedTemporaryFile.return_value.__enter__.return_value.name = FILENAME mock_hook.return_value.get_report.return_value = { "metadata": {"status": {"state": "DONE"}, "googleCloudStoragePath": "TEST"} } - dag = DAG( - dag_id="test_set_bucket_name", - start_date=DEFAULT_DATE, - schedule=None, - catchup=False, - ) + with dag_maker(dag_id="test_set_bucket_name", start_date=DEFAULT_DATE) as dag: + if BUCKET_NAME not in test_bucket_name: - if BUCKET_NAME not in test_bucket_name: + 
@dag.task(task_id="taskflow_op") + def f(): + return BUCKET_NAME - @dag.task - def f(): - return BUCKET_NAME + taskflow_op = f() - taskflow_op = f() - taskflow_op.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + GoogleDisplayVideo360DownloadReportV2Operator( + query_id=QUERY_ID, + report_id=REPORT_ID, + bucket_name=test_bucket_name if test_bucket_name != "XComArg" else taskflow_op, + report_name=REPORT_NAME, + task_id="test_task", + ) - op = GoogleDisplayVideo360DownloadReportV2Operator( - query_id=QUERY_ID, - report_id=REPORT_ID, - bucket_name=test_bucket_name if test_bucket_name != "XComArg" else taskflow_op, - report_name=REPORT_NAME, - task_id="test_task", - dag=dag, - ) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dr = dag_maker.create_dagrun() + + for ti in dr.get_task_instances(): + ti.run() mock_gcs_hook.return_value.upload.assert_called_once_with( bucket_name=BUCKET_NAME, diff --git a/providers/openlineage/tests/system/openlineage/example_openlineage_base_complex_dag.py b/providers/openlineage/tests/system/openlineage/example_openlineage_base_complex_dag.py index 43ddaab341a6f..fd97a4040a656 100644 --- a/providers/openlineage/tests/system/openlineage/example_openlineage_base_complex_dag.py +++ b/providers/openlineage/tests/system/openlineage/example_openlineage_base_complex_dag.py @@ -31,6 +31,7 @@ from __future__ import annotations from datetime import datetime, timedelta +from typing import Any from airflow import DAG from airflow.models import Variable @@ -106,7 +107,9 @@ def __init__(self, **kwargs): task_5 = PythonOperator(task_id="task_5", python_callable=lambda: 1) with TaskGroup("section_2", parent_group=tg, tooltip="group_tooltip") as tg2: if AIRFLOW_VERSION.major == 3: - add_args = {"run_as_user": "some_user"} # Random user break task execution on AF2 + add_args: dict[str, Any] = { + "run_as_user": "some_user" + } # Random user break task execution on AF2 else: add_args = {"sla": timedelta(seconds=123)} # type: ignore[dict-item] # SLA is not present in AF3 yet task_6 = EmptyOperator(task_id="task_6", on_success_callback=lambda x: print(1), **add_args) diff --git a/providers/openlineage/tests/unit/openlineage/extractors/test_base.py b/providers/openlineage/tests/unit/openlineage/extractors/test_base.py index d14933ca6e822..e2cf9e7e76301 100644 --- a/providers/openlineage/tests/unit/openlineage/extractors/test_base.py +++ b/providers/openlineage/tests/unit/openlineage/extractors/test_base.py @@ -24,7 +24,13 @@ from openlineage.client.event_v2 import Dataset from openlineage.client.facet_v2 import BaseFacet, JobFacet, parent_run, sql_job -from airflow.models.baseoperator import BaseOperator +from airflow.providers.openlineage.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator +else: + from airflow.models.baseoperator import BaseOperator + from airflow.models.taskinstance import TaskInstanceState from airflow.providers.openlineage.extractors.base import ( BaseExtractor, diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py index 956b88e96dde8..0b686727a9a59 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py @@ -32,7 +32,13 @@ from uuid6 import uuid7 from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.baseoperator import BaseOperator +from 
airflow.providers.openlineage.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator +else: + from airflow.models.baseoperator import BaseOperator + from airflow.providers.openlineage.extractors.base import OperatorLineage from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter from airflow.providers.openlineage.plugins.listener import OpenLineageListener diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py index 9130098e0bc50..753c6b166590b 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py @@ -30,10 +30,11 @@ from airflow.utils import timezone if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import task + from airflow.sdk import BaseOperator, task else: from airflow.decorators import task -from airflow.models.baseoperator import BaseOperator + from airflow.models.baseoperator import BaseOperator + from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance, TaskInstanceState from airflow.providers.common.compat.assets import Asset diff --git a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py index f51b5d6a6acca..1a3d39d0d1b3f 100644 --- a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py @@ -69,12 +69,11 @@ def func2(): return number with dag_maker(dag_id=TEST_DAG_ID): - rets = [func1(), func2()] + _ = [func1(), func2()] dr = dag_maker.create_dagrun() - for ret in rets: - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) for ti in dr.get_task_instances(): + ti.run() assert ti.xcom_pull() == number assert mock_snowflake_hook.call_count == 2 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 2 @@ -122,12 +121,11 @@ def func3(number: int): return number with dag_maker(dag_id=TEST_DAG_ID): - rets = [func1(number=number), func2(number=number), func3(number=number)] + _ = [func1(number=number), func2(number=number), func3(number=number)] dr = dag_maker.create_dagrun() - for ret in rets: - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) for ti in dr.get_task_instances(): + ti.run() assert ti.xcom_pull() == number assert mock_snowflake_hook.call_count == 3 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 3 @@ -147,11 +145,11 @@ def func(session: Session): assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value with dag_maker(dag_id=TEST_DAG_ID): - ret = func() + func() dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) for ti in dr.get_task_instances(): + ti.run() assert ti.xcom_pull() is None mock_snowflake_hook.assert_called_once() mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once() @@ -181,11 +179,11 @@ def func(session: Session): assert run_task.xcom.get(key="return_value") == {"a": 1, "b": "2"} else: with dag_maker(dag_id=TEST_DAG_ID): - ret = func() + func() dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) ti = dr.get_task_instances()[0] + ti.run() assert ti.xcom_pull(key="a") == 1 assert ti.xcom_pull(key="b") == "2" assert ti.xcom_pull() == {"a": 1, "b": "2"} @@ -216,11 +214,11 @@ def func(session: 
Session): return session.query_tag with dag_maker(dag_id=TEST_DAG_ID): - ret = func() + func() dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) ti = dr.get_task_instances()[0] + ti.run() query_tag = ti.xcom_pull() assert query_tag == { "dag_id": TEST_DAG_ID, diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py index b39bf3c105cab..2d44f5d44bc84 100644 --- a/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py @@ -51,7 +51,7 @@ def func1(session: Session): def func2(): return number - operators = [ + _ = [ SnowparkOperator( task_id=f"{TASK_ID}_{i}", snowflake_conn_id=CONN_ID, @@ -67,9 +67,8 @@ def func2(): ] dr = dag_maker.create_dagrun() - for operator in operators: - operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) for ti in dr.get_task_instances(): + ti.run() assert ti.xcom_pull() == number assert mock_snowflake_hook.call_count == 2 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 2 @@ -91,7 +90,7 @@ def func2(number: int, session: Session): def func3(number: int): return number - operators = [ + _ = [ SnowparkOperator( task_id=f"{TASK_ID}_{i}", snowflake_conn_id=CONN_ID, @@ -108,9 +107,8 @@ def func3(number: int): ] dr = dag_maker.create_dagrun() - for operator in operators: - operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) for ti in dr.get_task_instances(): + ti.run() assert ti.xcom_pull() == number assert mock_snowflake_hook.call_count == 3 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 3 @@ -122,7 +120,7 @@ def test_snowpark_operator_no_return(self, mock_snowflake_hook, dag_maker): def func(session: Session): assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value - operator = SnowparkOperator( + SnowparkOperator( task_id=TASK_ID, snowflake_conn_id=CONN_ID, python_callable=func, @@ -135,8 +133,8 @@ def func(session: Session): ) dr = dag_maker.create_dagrun() - operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) for ti in dr.get_task_instances(): + ti.run() assert ti.xcom_pull() is None mock_snowflake_hook.assert_called_once() mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once() @@ -157,7 +155,7 @@ def update_query_tag(new_tags): def func(session: Session): return session.query_tag - operator = SnowparkOperator( + SnowparkOperator( task_id=TASK_ID, snowflake_conn_id=CONN_ID, python_callable=func, @@ -170,8 +168,8 @@ def func(session: Session): ) dr = dag_maker.create_dagrun() - operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) ti = dr.get_task_instances()[0] + ti.run() query_tag = ti.xcom_pull() assert query_tag == { "dag_id": TEST_DAG_ID, diff --git a/providers/standard/src/airflow/providers/standard/operators/bash.py b/providers/standard/src/airflow/providers/standard/operators/bash.py index 84db1e081e7f2..04a2bfedf3be3 100644 --- a/providers/standard/src/airflow/providers/standard/operators/bash.py +++ b/providers/standard/src/airflow/providers/standard/operators/bash.py @@ -25,13 +25,14 @@ from typing import TYPE_CHECKING, Any, Callable, cast from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.models.baseoperator import BaseOperator from airflow.providers.standard.hooks.subprocess import SubprocessHook, SubprocessResult, working_directory from 
airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator from airflow.sdk.execution_time.context import context_to_airflow_vars else: + from airflow.models.baseoperator import BaseOperator from airflow.utils.operator_helpers import context_to_airflow_vars # type: ignore[no-redef, attr-defined] if TYPE_CHECKING: diff --git a/providers/standard/src/airflow/providers/standard/operators/branch.py b/providers/standard/src/airflow/providers/standard/operators/branch.py index cc50d1e25d6fa..3db925ace82aa 100644 --- a/providers/standard/src/airflow/providers/standard/operators/branch.py +++ b/providers/standard/src/airflow/providers/standard/operators/branch.py @@ -22,12 +22,13 @@ from collections.abc import Iterable from typing import TYPE_CHECKING -from airflow.models.baseoperator import BaseOperator from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: from airflow.providers.standard.utils.skipmixin import SkipMixin + from airflow.sdk import BaseOperator else: + from airflow.models.baseoperator import BaseOperator from airflow.models.skipmixin import SkipMixin if TYPE_CHECKING: diff --git a/providers/standard/src/airflow/providers/standard/operators/empty.py b/providers/standard/src/airflow/providers/standard/operators/empty.py index dc906b4b3c17b..3d91a574e9221 100644 --- a/providers/standard/src/airflow/providers/standard/operators/empty.py +++ b/providers/standard/src/airflow/providers/standard/operators/empty.py @@ -18,7 +18,12 @@ from typing import TYPE_CHECKING -from airflow.models.baseoperator import BaseOperator +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator +else: + from airflow.models.baseoperator import BaseOperator if TYPE_CHECKING: from airflow.sdk.definitions.context import Context diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py index 5a8bd9e87b657..b8846259ab2fc 100644 --- a/providers/standard/src/airflow/providers/standard/operators/python.py +++ b/providers/standard/src/airflow/providers/standard/operators/python.py @@ -48,7 +48,6 @@ AirflowSkipException, DeserializingResultError, ) -from airflow.models.baseoperator import BaseOperator from airflow.models.variable import Variable from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv, write_python_script from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS @@ -61,7 +60,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.providers.standard.operators.branch import BaseBranchOperator from airflow.providers.standard.utils.skipmixin import SkipMixin + from airflow.sdk import BaseOperator else: + from airflow.models.baseoperator import BaseOperator from airflow.models.skipmixin import SkipMixin from airflow.operators.branch import BaseBranchOperator # type: ignore[no-redef] diff --git a/providers/standard/src/airflow/providers/standard/operators/smooth.py b/providers/standard/src/airflow/providers/standard/operators/smooth.py index 8aeb35a160045..33f089421f741 100644 --- a/providers/standard/src/airflow/providers/standard/operators/smooth.py +++ b/providers/standard/src/airflow/providers/standard/operators/smooth.py @@ -19,7 +19,12 @@ from typing import TYPE_CHECKING -from airflow.models.baseoperator import BaseOperator +from airflow.providers.standard.version_compat import 
AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator +else: + from airflow.models.baseoperator import BaseOperator if TYPE_CHECKING: from airflow.sdk.definitions.context import Context diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index aa8a971a526b8..7c48f9a449b04 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -54,11 +54,10 @@ from airflow.models.taskinstancekey import TaskInstanceKey - try: + if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperator from airflow.sdk.definitions.context import Context - except ImportError: - # TODO: Remove once provider drops support for Airflow 2 + else: from airflow.models.baseoperator import BaseOperator from airflow.utils.context import Context diff --git a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py index 525662c970532..4351c6c8248af 100644 --- a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py +++ b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py @@ -22,13 +22,18 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: - from airflow.models.operator import Operator from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.types import RuntimeTaskInstanceProtocol + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.types import Operator + else: + from airflow.models.operator import Operator + # The key used by SkipMixin to store XCom data. 
XCOM_SKIPMIXIN_KEY = "skipmixin_key" @@ -40,8 +45,12 @@ def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]: - from airflow.models.baseoperator import BaseOperator - from airflow.models.mappedoperator import MappedOperator + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator + else: + from airflow.models.baseoperator import BaseOperator + from airflow.models.mappedoperator import MappedOperator return [n for n in nodes if isinstance(n, (BaseOperator, MappedOperator))] diff --git a/providers/standard/tests/unit/standard/decorators/test_branch_external_python.py b/providers/standard/tests/unit/standard/decorators/test_branch_external_python.py index f0283c0307493..ceaf5df5fae1e 100644 --- a/providers/standard/tests/unit/standard/decorators/test_branch_external_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_branch_external_python.py @@ -79,19 +79,15 @@ def branch_operator(): branchoperator.set_downstream(task_2) dr = dag_maker.create_dagrun() - df.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) + dag_maker.run_ti("dummy_f", dr) if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: - branchoperator.operator.run( - start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True - ) + dag_maker.run_ti("branching", dr) assert exc_info.value.tasks == [(skipped_task_name, -1)] else: - branchoperator.operator.run( - start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True - ) - task_1.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) - task_2.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) + dag_maker.run_ti("branching", dr) + dag_maker.run_ti("task_1", dr) + dag_maker.run_ti("task_2", dr) tis = dr.get_task_instances() for ti in tis: diff --git a/providers/standard/tests/unit/standard/decorators/test_branch_python.py b/providers/standard/tests/unit/standard/decorators/test_branch_python.py index 3d8a46d7a37cd..c9475acec1dd0 100644 --- a/providers/standard/tests/unit/standard/decorators/test_branch_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_branch_python.py @@ -66,19 +66,15 @@ def branch_operator(): branchoperator.set_downstream(task_2) dr = dag_maker.create_dagrun() - df.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) + dag_maker.run_ti("dummy_f", dr) if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: - branchoperator.operator.run( - start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True - ) + dag_maker.run_ti("branching", dr) assert exc_info.value.tasks == [(skipped_task_name, -1)] else: - branchoperator.operator.run( - start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True - ) - task_1.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) - task_2.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) + dag_maker.run_ti("branching", dr) + dag_maker.run_ti("task_1", dr) + dag_maker.run_ti("task_2", dr) tis = dr.get_task_instances() for ti in tis: diff --git a/providers/standard/tests/unit/standard/decorators/test_branch_virtualenv.py b/providers/standard/tests/unit/standard/decorators/test_branch_virtualenv.py index 170916c21a31b..271855f4f4b24 100644 --- 
a/providers/standard/tests/unit/standard/decorators/test_branch_virtualenv.py +++ b/providers/standard/tests/unit/standard/decorators/test_branch_virtualenv.py @@ -95,19 +95,15 @@ def branch_operator(): branchoperator.set_downstream(task_2) dr = dag_maker.create_dagrun() - df.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) + dag_maker.run_ti("dummy_f", dr) if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: - branchoperator.operator.run( - start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True - ) + dag_maker.run_ti("branching", dr) assert exc_info.value.tasks == [(skipped_task_name, -1)] else: - branchoperator.operator.run( - start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True - ) - task_1.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) - task_2.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, ignore_ti_state=True) + dag_maker.run_ti("branching", dr) + dag_maker.run_ti("task_1", dr) + dag_maker.run_ti("task_2", dr) tis = dr.get_task_instances() for ti in tis: diff --git a/providers/standard/tests/unit/standard/decorators/test_external_python.py b/providers/standard/tests/unit/standard/decorators/test_external_python.py index 1b4bba68c2430..9c532616cedbf 100644 --- a/providers/standard/tests/unit/standard/decorators/test_external_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_external_python.py @@ -81,10 +81,11 @@ def f(): import dill # noqa: F401 with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dr = dag_maker.create_dagrun() + ti = dr.get_task_instances()[0] + ti.run() @pytest.mark.parametrize( "serializer", @@ -106,10 +107,10 @@ def f(): import dill # noqa: F401 with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() - - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + f() + dr = dag_maker.create_dagrun() + ti = dr.get_task_instances()[0] + ti.run() @pytest.mark.parametrize( "serializer", @@ -124,11 +125,12 @@ def f(): pass with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() + ti = dr.get_task_instances()[0] with pytest.raises(CalledProcessError): - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + ti.run() def test_exception_raises_error(self, dag_maker, venv_python): @task.external_python(python=venv_python) @@ -136,11 +138,12 @@ def f(): raise Exception with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() + ti = dr.get_task_instances()[0] with pytest.raises(CalledProcessError): - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + ti.run() @pytest.mark.parametrize( "serializer", @@ -159,10 +162,10 @@ def f(a, b, c=False, d=False): raise Exception with dag_maker(serialized=True): - ret = f(0, 1, c=True) - dag_maker.create_dagrun() - - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + f(0, 1, c=True) + dr = dag_maker.create_dagrun() + ti = dr.get_task_instances()[0] + ti.run() @pytest.mark.parametrize( "serializer", @@ -179,10 +182,11 @@ def f(): return None with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dr = dag_maker.create_dagrun() + ti = dr.get_task_instances()[0] + ti.run() @pytest.mark.parametrize( 
"serializer", @@ -199,10 +203,11 @@ def f(_): return None with dag_maker(serialized=True): - ret = f(datetime.datetime.now(tz=datetime.timezone.utc)) - dag_maker.create_dagrun() + f(datetime.datetime.now(tz=datetime.timezone.utc)) - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dr = dag_maker.create_dagrun() + ti = dr.get_task_instances()[0] + ti.run() @pytest.mark.parametrize( "serializer", @@ -222,13 +227,14 @@ def f(): return 1 with dag_maker(serialized=True) as dag: - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() assert len(dag.task_group.children) == 1 setup_task = dag.task_group.children["f"] assert setup_task.is_setup - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + ti = dr.get_task_instances()[0] + ti.run() @pytest.mark.parametrize( "serializer", @@ -248,13 +254,14 @@ def f(): return 1 with dag_maker(serialized=True) as dag: - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() assert len(dag.task_group.children) == 1 teardown_task = dag.task_group.children["f"] assert teardown_task.is_teardown - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + ti = dr.get_task_instances()[0] + ti.run() @pytest.mark.parametrize( "serializer", @@ -275,11 +282,12 @@ def f(): return 1 with dag_maker(serialized=True) as dag: - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() assert len(dag.task_group.children) == 1 teardown_task = dag.task_group.children["f"] assert teardown_task.is_teardown assert teardown_task.on_failure_fail_dagrun is on_failure_fail_dagrun - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + ti = dr.get_task_instances()[0] + ti.run() diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index ed8d5ef0cdb97..4bab68cc04789 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -27,10 +27,8 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.utils import timezone -from airflow.utils.state import State from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.trigger_rule import TriggerRule -from airflow.utils.types import DagRunType from airflow.utils.xcom import XCOM_RETURN_KEY from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS @@ -41,7 +39,6 @@ from airflow.sdk.bases.decorator import DecoratedMappedOperator from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput from airflow.sdk.definitions.mappedoperator import MappedOperator - from airflow.utils.types import DagRunTriggeredByType else: from airflow.decorators import setup, task as task_decorator, teardown from airflow.decorators.base import DecoratedMappedOperator # type: ignore[no-redef] @@ -317,56 +314,56 @@ class Test: def add_number(self, num: int) -> int: return self.num + num - def test_fail_multiple_outputs_key_type(self): + def test_fail_multiple_outputs_key_type(self, dag_maker): @task_decorator(multiple_outputs=True) def add_number(num: int): return {2: num} - with self.dag_non_serialized: - ret = add_number(2) + with dag_maker(): + add_number(2) - self.create_dag_run() + dr = dag_maker.create_dagrun() error_expected = AirflowException if (not AIRFLOW_V_3_0_PLUS or AIRFLOW_V_3_0_1) else TypeError with 
pytest.raises(error_expected): - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("add_number", dr) - def test_fail_multiple_outputs_no_dict(self): + def test_fail_multiple_outputs_no_dict(self, dag_maker): @task_decorator(multiple_outputs=True) def add_number(num: int): return num - with self.dag_non_serialized: - ret = add_number(2) + with dag_maker(): + add_number(2) - self.create_dag_run() + dr = dag_maker.create_dagrun() error_expected = AirflowException if (not AIRFLOW_V_3_0_PLUS or AIRFLOW_V_3_0_1) else TypeError with pytest.raises(error_expected): - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("add_number", dr) - def test_multiple_outputs_empty_dict(self): + def test_multiple_outputs_empty_dict(self, dag_maker): @task_decorator(multiple_outputs=True) def empty_dict(): return {} - with self.dag_non_serialized: - ret = empty_dict() + with dag_maker(): + empty_dict() - dr = self.create_dag_run() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dr = dag_maker.create_dagrun() + dag_maker.run_ti("empty_dict", dr) ti = dr.get_task_instances()[0] assert ti.xcom_pull() == {} - def test_multiple_outputs_return_none(self): + def test_multiple_outputs_return_none(self, dag_maker): @task_decorator(multiple_outputs=True) def test_func(): return - with self.dag_non_serialized: - ret = test_func() + with dag_maker(): + test_func() - dr = self.create_dag_run() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dr = dag_maker.create_dagrun() + dag_maker.run_ti("test_func", dr) ti = dr.get_task_instances()[0] assert ti.xcom_pull() is None @@ -472,7 +469,7 @@ def __do_run(): assert self.dag_non_serialized.task_ids[-1] == "__do_run__20" - def test_multiple_outputs(self): + def test_multiple_outputs(self, dag_maker): """Tests pushing multiple outputs as a dictionary""" @task_decorator(multiple_outputs=True) @@ -480,32 +477,11 @@ def return_dict(number: int): return {"number": number + 1, "43": 43} test_number = 10 - with self.dag_non_serialized: - ret = return_dict(test_number) - - v3_kwargs = ( - { - "run_after": DEFAULT_DATE, - "triggered_by": DagRunTriggeredByType.TEST, - "logical_date": DEFAULT_DATE, - } - if AIRFLOW_V_3_0_PLUS - else { - "execution_date": DEFAULT_DATE, - } - ) - dr = self.dag_non_serialized.create_dagrun( - run_id="test", - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - state=State.RUNNING, - data_interval=self.dag_non_serialized.timetable.infer_manual_data_interval( - run_after=DEFAULT_DATE - ), - **v3_kwargs, - ) + with dag_maker(): + return_dict(test_number) - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dr = dag_maker.create_dagrun() + dag_maker.run_ti("return_dict", dr) ti = dr.get_task_instances()[0] assert ti.xcom_pull(key="number") == test_number + 1 @@ -540,7 +516,7 @@ def test_apply_default(owner): ret = test_apply_default() assert "owner" in ret.operator.op_kwargs - def test_xcom_arg(self): + def test_xcom_arg(self, dag_maker): """Tests that returned key in XComArg is returned correctly""" @task_decorator @@ -553,35 +529,15 @@ def add_num(number: int, num2: int = 2): test_number = 10 - with self.dag_non_serialized: + with dag_maker(): bigger_number = add_2(test_number) ret = add_num(bigger_number, XComArg(bigger_number.operator)) - v3_kwargs = ( - { - "run_after": DEFAULT_DATE, - "triggered_by": DagRunTriggeredByType.TEST, - "logical_date": DEFAULT_DATE, - } - if AIRFLOW_V_3_0_PLUS - else { - "execution_date": DEFAULT_DATE, 
- } - ) - dr = self.dag_non_serialized.create_dagrun( - run_id="test", - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - state=State.RUNNING, - data_interval=self.dag_non_serialized.timetable.infer_manual_data_interval( - run_after=DEFAULT_DATE - ), - **v3_kwargs, - ) + dr = dag_maker.create_dagrun() - bigger_number.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("add_2", dr) + dag_maker.run_ti("add_num", dr) - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) ti_add_num = next(ti for ti in dr.get_task_instances() if ti.task_id == "add_num") assert ti_add_num.xcom_pull(key=ret.key) == (test_number + 2) * 2 @@ -680,7 +636,7 @@ def hello(): weights.append(task.priority_weight) assert weights == [0, 1, 2] - def test_python_callable_args_work_as_well_as_baseoperator_args(self): + def test_python_callable_args_work_as_well_as_baseoperator_args(self, dag_maker): """Tests that when looping that user provided pool, priority_weight etc is used""" @task_decorator(task_id="hello_task") @@ -691,14 +647,15 @@ def hello(x, y): print("Hello world", x, y) return x, y - with self.dag_non_serialized: + with dag_maker(): output = hello.override(task_id="mytask")(x=2, y=3) output2 = hello.override()(2, 3) # nothing overridden but should work + dr = dag_maker.create_dagrun() assert output.operator.op_kwargs == {"x": 2, "y": 3} assert output2.operator.op_args == (2, 3) - output.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - output2.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("mytask", dr) + dag_maker.run_ti("hello_task", dr) def test_mapped_decorator_shadow_context() -> None: diff --git a/providers/standard/tests/unit/standard/decorators/test_python_virtualenv.py b/providers/standard/tests/unit/standard/decorators/test_python_virtualenv.py index f67fa6a947518..8c26da96b7a5f 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python_virtualenv.py +++ b/providers/standard/tests/unit/standard/decorators/test_python_virtualenv.py @@ -57,10 +57,10 @@ def f(): import cloudpickle # noqa: F401 with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) @DILL_MARKER def test_add_dill(self, dag_maker): @@ -70,10 +70,10 @@ def f(): import dill # noqa: F401 with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) def test_no_requirements(self, dag_maker): """Tests that the python callable is invoked on task run.""" @@ -83,10 +83,10 @@ def f(): pass with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) @pytest.mark.parametrize( "serializer", @@ -105,10 +105,10 @@ def f(): raise Exception with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) @pytest.mark.parametrize( "serializer", @@ -128,10 +128,10 @@ def f(): import funcsigs # noqa: F401 with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, 
end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) @pytest.mark.parametrize( "serializer", @@ -156,10 +156,10 @@ def f(): raise Exception with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) @pytest.mark.parametrize( "serializer", @@ -192,10 +192,10 @@ def f(): raise Exception with dag_maker(template_searchpath=tmp_path.as_posix(), serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) @pytest.mark.parametrize( "serializer, extra_requirements", @@ -217,10 +217,10 @@ def f(): import funcsigs # noqa: F401 with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) @pytest.mark.parametrize( "serializer", @@ -237,11 +237,11 @@ def f(): raise Exception with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() with pytest.raises(CalledProcessError): - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) @pytest.mark.parametrize( "serializer, extra_requirements", @@ -265,10 +265,10 @@ def f(): raise Exception with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) @pytest.mark.parametrize( "serializer, extra_requirements", @@ -287,10 +287,10 @@ def f(a, b, c=False, d=False): raise Exception with dag_maker(serialized=True): - ret = f(0, 1, c=True) - dag_maker.create_dagrun() + f(0, 1, c=True) + dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) def test_return_none(self, dag_maker): @task.virtualenv @@ -298,10 +298,10 @@ def f(): return None with dag_maker(serialized=True): - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) def test_nonimported_as_arg(self, dag_maker): @task.virtualenv @@ -309,10 +309,10 @@ def f(_): return None with dag_maker(serialized=True): - ret = f(datetime.datetime.now(tz=datetime.timezone.utc)) - dag_maker.create_dagrun() + f(datetime.datetime.now(tz=datetime.timezone.utc)) + dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) def test_marking_virtualenv_python_task_as_setup(self, dag_maker): @setup @@ -321,13 +321,13 @@ def f(): return 1 with dag_maker(serialized=True) as dag: - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() assert len(dag.task_group.children) == 1 setup_task = dag.task_group.children["f"] assert setup_task.is_setup - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) def test_marking_virtualenv_python_task_as_teardown(self, dag_maker): @teardown @@ -336,13 +336,13 @@ def f(): return 1 with dag_maker(serialized=True) as dag: - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() assert len(dag.task_group.children) == 1 teardown_task = dag.task_group.children["f"] assert teardown_task.is_teardown - 
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) @pytest.mark.parametrize("on_failure_fail_dagrun", [True, False]) def test_marking_virtualenv_python_task_as_teardown_with_on_failure_fail( @@ -354,14 +354,14 @@ def f(): return 1 with dag_maker(serialized=True) as dag: - ret = f() - dag_maker.create_dagrun() + f() + dr = dag_maker.create_dagrun() assert len(dag.task_group.children) == 1 teardown_task = dag.task_group.children["f"] assert teardown_task.is_teardown assert teardown_task.on_failure_fail_dagrun is on_failure_fail_dagrun - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("f", dr) def test_invalid_annotation(self, dag_maker): import uuid @@ -377,10 +377,10 @@ def in_venv(value: dict[str, _Invalid]) -> _Invalid: return value["unique_id"] with dag_maker(serialized=True): - ret = in_venv(value) + in_venv(value) dr = dag_maker.create_dagrun() - ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date) + dag_maker.run_ti("in_venv", dr) ti = dr.get_task_instances()[0] assert ti.state == TaskInstanceState.SUCCESS diff --git a/providers/standard/tests/unit/standard/decorators/test_short_circuit.py b/providers/standard/tests/unit/standard/decorators/test_short_circuit.py index 3ead1c252bbb3..2df79a0a199dd 100644 --- a/providers/standard/tests/unit/standard/decorators/test_short_circuit.py +++ b/providers/standard/tests/unit/standard/decorators/test_short_circuit.py @@ -64,8 +64,8 @@ def short_circuit(condition): dr = dag_maker.create_dagrun() - for t in dag_maker.dag.tasks: - t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + for dag_task in dag_maker.dag.tasks: + dag_maker.run_ti(dag_task.task_id, dr, ignore_ti_state=True) task_state_mapping = { "short_circuit_false": State.SUCCESS, @@ -180,10 +180,10 @@ def multiple_output(): return {"x": 1, "y": 2} with dag_maker(serialized=True): - ret = multiple_output() + multiple_output() dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("multiple_output", dr) ti = dr.get_task_instances()[0] assert ti.xcom_pull() == {"x": 1, "y": 2} @@ -194,9 +194,9 @@ def empty_dict(): return {} with dag_maker(serialized=True): - ret = empty_dict() + empty_dict() dr = dag_maker.create_dagrun() - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("empty_dict", dr) ti = dr.get_task_instances()[0] assert ti.xcom_pull() == {} diff --git a/providers/standard/tests/unit/standard/operators/test_bash.py b/providers/standard/tests/unit/standard/operators/test_bash.py index fe33689d00e33..976ff48d9ef57 100644 --- a/providers/standard/tests/unit/standard/operators/test_bash.py +++ b/providers/standard/tests/unit/standard/operators/test_bash.py @@ -93,7 +93,7 @@ def test_echo_env_variables( serialized=True, ): tmp_file = tmp_path / "testfile" - task = BashOperator( + BashOperator( task_id="echo_env_vars", bash_command=f"echo $AIRFLOW_HOME>> {tmp_file};" f"echo $PYTHONPATH>> {tmp_file};" @@ -106,7 +106,7 @@ def test_echo_env_variables( ) logical_date = utc_now - dag_maker.create_dagrun( + dr = dag_maker.create_dagrun( run_type=DagRunType.MANUAL, logical_date=logical_date, start_date=utc_now, @@ -117,7 +117,7 @@ def test_echo_env_variables( with mock.patch.dict( "os.environ", {"AIRFLOW_HOME": "MY_PATH_TO_AIRFLOW_HOME", "PYTHONPATH": "AWESOME_PYTHONPATH"} ): - task.run(utc_now, utc_now, ignore_first_depends_on_past=True, ignore_ti_state=True) + 
dag_maker.run_ti("echo_env_vars", dr) assert expected == tmp_file.read_text() @@ -249,14 +249,14 @@ def test_bash_operator_kill(self, dag_maker): sleep_time = f"100{os.getpid()}" with dag_maker(serialized=True): - op = BashOperator( + BashOperator( task_id="test_bash_operator_kill", execution_timeout=timedelta(microseconds=25), bash_command=f"/bin/bash -c 'sleep {sleep_time}'", ) - dag_maker.create_dagrun() + dr = dag_maker.create_dagrun() with pytest.raises(AirflowTaskTimeout): - op.run() + dag_maker.run_ti("test_bash_operator_kill", dr) sleep(2) for proc in psutil.process_iter(): if proc.cmdline() == ["sleep", sleep_time]: diff --git a/providers/standard/tests/unit/standard/operators/test_branch_operator.py b/providers/standard/tests/unit/standard/operators/test_branch_operator.py index 821e7cfb9c675..670ce77415ba0 100644 --- a/providers/standard/tests/unit/standard/operators/test_branch_operator.py +++ b/providers/standard/tests/unit/standard/operators/test_branch_operator.py @@ -74,15 +74,15 @@ def test_without_dag_run(self, dag_maker): branch_op = ChooseBranchOne(task_id="make_choice") branch_1.set_upstream(branch_op) branch_2.set_upstream(branch_op) - dag_maker.create_dagrun(**triggered_by_kwargs) + dr = dag_maker.create_dagrun(**triggered_by_kwargs) if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) assert exc_info.value.tasks == [("branch_2", -1)] else: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) ti_date = TI.logical_date if AIRFLOW_V_3_0_PLUS else TI.execution_date @@ -90,7 +90,6 @@ def test_without_dag_run(self, dag_maker): if ti.task_id == "make_choice": assert ti.state == State.SUCCESS elif ti.task_id == "branch_1": - # should exist with state None assert ti.state == State.NONE elif ti.task_id == "branch_2": assert ti.state == State.SKIPPED @@ -114,15 +113,15 @@ def test_branch_list_without_dag_run(self, dag_maker): branch_1.set_upstream(branch_op) branch_2.set_upstream(branch_op) branch_3.set_upstream(branch_op) - dag_maker.create_dagrun(**triggered_by_kwargs) + dr = dag_maker.create_dagrun(**triggered_by_kwargs) if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) assert exc_info.value.tasks == [("branch_3", -1)] else: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) expected = { "make_choice": State.SUCCESS, @@ -154,7 +153,7 @@ def test_with_dag_run(self, dag_maker): branch_1.set_upstream(branch_op) branch_2.set_upstream(branch_op) if AIRFLOW_V_3_0_1: - dag_maker.create_dagrun( + dr = dag_maker.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), logical_date=DEFAULT_DATE, @@ -164,11 +163,11 @@ def test_with_dag_run(self, dag_maker): ) with pytest.raises(DownstreamTasksSkipped) as exc_info: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) assert exc_info.value.tasks == [("branch_2", -1)] else: - dag_maker.create_dagrun( + dr = dag_maker.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, @@ -177,7 +176,7 @@ def test_with_dag_run(self, dag_maker): **triggered_by_kwargs, ) - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) expected = { "make_choice": 
State.SUCCESS, @@ -209,7 +208,7 @@ def test_with_skip_in_branch_downstream_dependencies(self, dag_maker): branch_op >> branch_2 if AIRFLOW_V_3_0_PLUS: - dag_maker.create_dagrun( + dr = dag_maker.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), logical_date=DEFAULT_DATE, @@ -218,7 +217,7 @@ def test_with_skip_in_branch_downstream_dependencies(self, dag_maker): **triggered_by_kwargs, ) else: - dag_maker.create_dagrun( + dr = dag_maker.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, @@ -227,7 +226,7 @@ def test_with_skip_in_branch_downstream_dependencies(self, dag_maker): **triggered_by_kwargs, ) - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) expected = { "make_choice": State.SUCCESS, @@ -276,12 +275,12 @@ def test_xcom_push(self, dag_maker): if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as exc_info: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) assert exc_info.value.tasks == [("branch_2", -1)] else: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - branch_op_ti = dr.get_task_instance(branch_op.task_id) + dag_maker.run_ti("make_choice", dr) + branch_op_ti = dr.get_task_instance("make_choice") assert branch_op_ti.xcom_pull(task_ids="make_choice", key=XCOM_SKIPMIXIN_KEY) == { XCOM_SKIPMIXIN_FOLLOWED: ["branch_1"] } @@ -306,7 +305,7 @@ def test_with_dag_run_task_groups(self, dag_maker): branch_3.set_upstream(branch_op) if AIRFLOW_V_3_0_1: - dag_maker.create_dagrun( + dr = dag_maker.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), logical_date=DEFAULT_DATE, @@ -316,11 +315,11 @@ def test_with_dag_run_task_groups(self, dag_maker): ) with pytest.raises(DownstreamTasksSkipped) as exc_info: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) assert set(exc_info.value.tasks) == {("branch_1", -1), ("branch_2", -1)} else: - dag_maker.create_dagrun( + dr = dag_maker.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, @@ -328,7 +327,7 @@ def test_with_dag_run_task_groups(self, dag_maker): data_interval=DataInterval(DEFAULT_DATE, DEFAULT_DATE), **triggered_by_kwargs, ) - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) ti_date = TI.logical_date if AIRFLOW_V_3_0_PLUS else TI.execution_date diff --git a/providers/standard/tests/unit/standard/operators/test_datetime.py b/providers/standard/tests/unit/standard/operators/test_datetime.py index 0c6e40381793f..afc1d8ff012e2 100644 --- a/providers/standard/tests/unit/standard/operators/test_datetime.py +++ b/providers/standard/tests/unit/standard/operators/test_datetime.py @@ -56,6 +56,7 @@ def setup_class(cls): @pytest.fixture(autouse=True) def base_tests_setup(self, dag_maker): + self.dag_maker = dag_maker # Store dag_maker for use in test methods with dag_maker( "branch_datetime_operator_test", default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, @@ -128,10 +129,10 @@ def test_branch_datetime_operator_falls_within_range(self, target_lower, target_ from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.dag_maker.run_ti("datetime_branch", self.dr) assert exc_info.value.tasks == [("branch_2", -1)] else: - 
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.dag_maker.run_ti("datetime_branch", self.dr) self._assert_task_ids_match_states( { @@ -161,12 +162,12 @@ def test_branch_datetime_operator_falls_outside_range(self, date, target_lower, from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info, time_machine.travel(date): - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.dag_maker.run_ti("datetime_branch", self.dr) assert exc_info.value.tasks == [("branch_1", -1)] else: with time_machine.travel(date): - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.dag_maker.run_ti("datetime_branch", self.dr) self._assert_task_ids_match_states( { @@ -187,11 +188,11 @@ def test_branch_datetime_operator_upper_comparison_within_range(self, target_upp from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.dag_maker.run_ti("datetime_branch", self.dr) assert exc_info.value.tasks == [("branch_2", -1)] else: - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.dag_maker.run_ti("datetime_branch", self.dr) self._assert_task_ids_match_states( { @@ -212,11 +213,11 @@ def test_branch_datetime_operator_lower_comparison_within_range(self, target_low from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.dag_maker.run_ti("datetime_branch", self.dr) assert exc_info.value.tasks == [("branch_2", -1)] else: - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.dag_maker.run_ti("datetime_branch", self.dr) self._assert_task_ids_match_states( { @@ -237,11 +238,11 @@ def test_branch_datetime_operator_upper_comparison_outside_range(self, target_up from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.dag_maker.run_ti("datetime_branch", self.dr) assert exc_info.value.tasks == [("branch_1", -1)] else: - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.dag_maker.run_ti("datetime_branch", self.dr) self._assert_task_ids_match_states( { @@ -262,11 +263,11 @@ def test_branch_datetime_operator_lower_comparison_outside_range(self, target_lo from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.dag_maker.run_ti("datetime_branch", self.dr) assert exc_info.value.tasks == [("branch_1", -1)] else: - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.dag_maker.run_ti("datetime_branch", self.dr) self._assert_task_ids_match_states( { @@ -299,11 +300,11 @@ def test_branch_datetime_operator_use_task_logical_date(self, dag_maker, target_ from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: - self.branch_op.run(start_date=in_between_date, end_date=in_between_date) + self.dag_maker.run_ti("datetime_branch", self.dr) assert exc_info.value.tasks == [("branch_2", -1)] else: - self.branch_op.run(start_date=in_between_date, end_date=in_between_date) + self.dag_maker.run_ti("datetime_branch", self.dr) self._assert_task_ids_match_states( { diff --git 
a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py index fce99a64b8278..066b15f9cf39f 100644 --- a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py +++ b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py @@ -83,9 +83,9 @@ def test_run(self, dag_maker): with dag_maker( default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, schedule=INTERVAL, serialized=True ): - task = LatestOnlyOperator(task_id="latest") - dag_maker.create_dagrun() - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + LatestOnlyOperator(task_id="latest") + dr = dag_maker.create_dagrun() + dag_maker.run_ti("latest", dr) def test_skipping_non_latest(self, dag_maker): with dag_maker( @@ -166,7 +166,8 @@ def test_skipping_non_latest(self, dag_maker): latest_ti2.task = latest_task latest_ti2.run() else: - latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) + for dr in [dr0, dr1, dr2]: + dag_maker.run_ti("latest", dr) if AIRFLOW_V_3_0_PLUS: date_getter = operator.attrgetter("logical_date") @@ -182,9 +183,10 @@ def test_skipping_non_latest(self, dag_maker): } # Verify the state of the other downstream tasks - downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE) + for dr in [dr0, dr1, dr2]: + dag_maker.run_ti("downstream", dr) + dag_maker.run_ti("downstream_2", dr) + dag_maker.run_ti("downstream_3", dr) downstream_instances = get_task_instances("downstream") exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances} @@ -251,9 +253,12 @@ def test_not_skipping_manual(self, dag_maker): **triggered_by_kwargs, ) - latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) + # Get all created dag runs and run tasks for each + all_drs = dag_maker.session.query(DagRun).filter_by(dag_id=dag_maker.dag.dag_id).all() + for dr in all_drs: + dag_maker.run_ti("latest", dr) + dag_maker.run_ti("downstream", dr) + dag_maker.run_ti("downstream_2", dr) latest_instances = get_task_instances("latest") if AIRFLOW_V_3_0_PLUS: diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index a63bb059e6863..184045f18a345 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -28,7 +28,7 @@ import warnings from collections import namedtuple from collections.abc import Generator -from datetime import date, datetime, timedelta, timezone as _timezone +from datetime import date, datetime, timezone as _timezone from functools import partial from importlib.util import find_spec from pathlib import Path @@ -47,8 +47,12 @@ AirflowProviderDeprecationWarning, DeserializingResultError, ) -from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import DAG +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator +else: + from airflow.models.baseoperator import BaseOperator from airflow.models.taskinstance import TaskInstance, clear_task_instances, set_current_context from 
airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import ( @@ -74,6 +78,7 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: + from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.utils.context import Context @@ -174,8 +179,8 @@ def run_as_operator(self, fn, **kwargs): clear_db_runs() with self.dag_maker(self.dag_id, template_searchpath=TEMPLATE_SEARCHPATH, serialized=True): task = self.opcls(task_id=self.task_id, python_callable=fn, **self.default_kwargs(**kwargs)) - self.dag_maker.create_dagrun() - task.run(start_date=self.default_date, end_date=self.default_date) + dr = self.dag_maker.create_dagrun() + self.dag_maker.run_ti(self.task_id, dr) clear_db_runs() return task @@ -405,13 +410,13 @@ def f(): branch_op = self.opcls(task_id=self.task_id, python_callable=f, **self.default_kwargs()) branch_op >> [self.branch_1, self.branch_2] - dr = self.create_dag_run() + dr = self.dag_maker.create_dagrun() if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as dts: - branch_op.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti(self.task_id, dr) assert dts.value.tasks == [("branch_2", -1)] else: - branch_op.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti(self.task_id, dr) self.assert_expected_task_states( dr, {self.task_id: State.SUCCESS, "branch_1": State.NONE, "branch_2": State.SKIPPED} ) @@ -427,8 +432,8 @@ def f(): branch_op >> self.branch_1 >> self.branch_2 branch_op >> self.branch_2 - dr = self.create_dag_run() - branch_op.run(start_date=self.default_date, end_date=self.default_date) + dr = self.dag_maker.create_dagrun() + self.dag_maker.run_ti(self.task_id, dr) self.assert_expected_task_states( dr, {self.task_id: State.SUCCESS, "branch_1": State.NONE, "branch_2": State.NONE} ) @@ -444,13 +449,13 @@ def f(): branch_op >> self.branch_1 >> self.branch_2 branch_op >> self.branch_2 - dr = self.create_dag_run() + dr = self.dag_maker.create_dagrun() if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as dts: branch_op.run(start_date=self.default_date, end_date=self.default_date) assert dts.value.tasks == [("branch_1", -1)] else: - branch_op.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti(branch_op.task_id, dr) self.assert_expected_task_states( dr, {self.task_id: State.SUCCESS, "branch_1": State.SKIPPED, "branch_2": State.NONE} ) @@ -487,9 +492,9 @@ def f(): branch_2_ti.run() assert branch_2_ti.state == TaskInstanceState.SKIPPED else: - branch_op.run(start_date=self.default_date, end_date=self.default_date) + dag_maker.run_ti(branch_op.task_id, dr) for task in branches: - task.run(start_date=self.default_date, end_date=self.default_date) + dag_maker.run_ti(task.task_id, dr) expected_states = { self.task_id: State.SUCCESS, @@ -510,7 +515,7 @@ def f(): # Run the cleared tasks again. for task in branches: - task.run(start_date=self.default_date, end_date=self.default_date) + dag_maker.run_ti(task.task_id, dr) # Check if the states are correct after children tasks are cleared. self.assert_expected_task_states(dr, expected_states) @@ -714,7 +719,7 @@ def test_short_circuiting( Checking the behavior of the ShortCircuitOperator in several scenarios enabling/disabling the skipping of downstream tasks, both short-circuiting modes, and various trigger rules of downstream tasks. 
""" - with self.dag_non_serialized: + with self.dag_maker(self.dag_id, template_searchpath=TEMPLATE_SEARCHPATH, serialized=True): short_circuit = ShortCircuitOperator( task_id="short_circuit", python_callable=lambda: callable_return, @@ -723,21 +728,21 @@ def test_short_circuiting( short_circuit >> self.op1 >> self.op2 self.op2.trigger_rule = test_trigger_rule - dr = self.create_dag_run() + dr = self.dag_maker.create_dagrun() if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped if expected_skipped_tasks: with pytest.raises(DownstreamTasksSkipped) as exc_info: - short_circuit.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti("short_circuit", dr) assert set(exc_info.value.tasks) == set(expected_skipped_tasks) else: - assert short_circuit.run(start_date=self.default_date, end_date=self.default_date) is None + assert self.dag_maker.run_ti("short_circuit", dr) is None else: - short_circuit.run(start_date=self.default_date, end_date=self.default_date) - self.op1.run(start_date=self.default_date, end_date=self.default_date) - self.op2.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti("short_circuit", dr) + self.dag_maker.run_ti("op1", dr) + self.dag_maker.run_ti("op2", dr) assert short_circuit.ignore_downstream_trigger_rules == test_ignore_downstream_trigger_rules assert short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS @@ -772,9 +777,9 @@ def test_clear_skipped_downstream_task(self): op1_ti.run() assert op1_ti.state == TaskInstanceState.SKIPPED else: - short_circuit.run(start_date=self.default_date, end_date=self.default_date) - self.op1.run(start_date=self.default_date, end_date=self.default_date) - self.op2.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti(self.task_id, dr) + self.dag_maker.run_ti(self.op1.task_id, dr) + self.dag_maker.run_ti(self.op2.task_id, dr) assert short_circuit.ignore_downstream_trigger_rules assert short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS assert self.op1.trigger_rule == TriggerRule.ALL_SUCCESS @@ -796,7 +801,7 @@ def test_clear_skipped_downstream_task(self): clear_task_instances( [ti for ti in tis if ti.task_id == "op1"], session=session, dag=short_circuit.dag ) - self.op1.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti("op1", dr) self.assert_expected_task_states(dr, expected_states) def test_xcom_push(self): @@ -809,29 +814,29 @@ def test_xcom_push(self): task_id="do_not_push_xcom_from_shortcircuit", python_callable=lambda: False ) - dr = self.create_dag_run() - short_op_push_xcom.run(start_date=self.default_date, end_date=self.default_date) - short_op_no_push_xcom.run(start_date=self.default_date, end_date=self.default_date) + dr = self.dag_maker.create_dagrun() + self.dag_maker.run_ti("push_xcom_from_shortcircuit", dr) + self.dag_maker.run_ti("do_not_push_xcom_from_shortcircuit", dr) tis = dr.get_task_instances() assert tis[0].xcom_pull(task_ids=short_op_push_xcom.task_id, key="return_value") == "signature" assert tis[0].xcom_pull(task_ids=short_op_no_push_xcom.task_id, key="return_value") is False def test_xcom_push_skipped_tasks(self): - with self.dag_non_serialized: + with self.dag_maker(self.dag_id, template_searchpath=TEMPLATE_SEARCHPATH, serialized=True): short_op_push_xcom = ShortCircuitOperator( task_id="push_xcom_from_shortcircuit", python_callable=lambda: False ) empty_task = EmptyOperator(task_id="empty_task") short_op_push_xcom >> empty_task - dr = self.create_dag_run() + 
dr = self.dag_maker.create_dagrun() if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped): short_op_push_xcom.run(start_date=self.default_date, end_date=self.default_date) else: - short_op_push_xcom.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti("push_xcom_from_shortcircuit", dr) tis = dr.get_task_instances() assert tis[0].xcom_pull(task_ids=short_op_push_xcom.task_id, key="skipmixin_key") == { "skipped": ["empty_task"] @@ -1739,14 +1744,13 @@ def f(): branch_op = self.opcls(task_id=self.task_id, python_callable=f, **self.default_kwargs()) branch_op >> [self.branch_1, self.branch_2] - dr = self.create_dag_run() + dr = self.dag_maker.create_dagrun() if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as dts: - branch_op.run(start_date=self.default_date, end_date=self.default_date) - + self.dag_maker.run_ti(self.task_id, dr) assert dts.value.tasks == [("branch_2", -1)] else: - branch_op.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti(self.task_id, dr) self.assert_expected_task_states( dr, {self.task_id: State.SUCCESS, "branch_1": State.NONE, "branch_2": State.SKIPPED} ) @@ -1762,8 +1766,8 @@ def f(): branch_op >> self.branch_1 >> self.branch_2 branch_op >> self.branch_2 - dr = self.create_dag_run() - branch_op.run(start_date=self.default_date, end_date=self.default_date) + dr = self.dag_maker.create_dagrun() + self.dag_maker.run_ti(branch_op.task_id, dr) self.assert_expected_task_states( dr, {self.task_id: State.SUCCESS, "branch_1": State.NONE, "branch_2": State.NONE} ) @@ -1779,15 +1783,15 @@ def f(): branch_op >> self.branch_1 >> self.branch_2 branch_op >> self.branch_2 - dr = self.create_dag_run() + dr = self.dag_maker.create_dagrun() if AIRFLOW_V_3_0_1: with pytest.raises(DownstreamTasksSkipped) as dts: - branch_op.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti(branch_op.task_id, dr) assert dts.value.tasks == [("branch_1", -1)] else: - branch_op.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti(branch_op.task_id, dr) self.assert_expected_task_states( dr, {self.task_id: State.SUCCESS, "branch_1": State.SKIPPED, "branch_2": State.NONE} ) @@ -1806,7 +1810,7 @@ def f(): branches = [self.branch_1, self.branch_2] branch_op >> branches - dr = self.create_dag_run() + dr = self.dag_maker.create_dagrun() if AIRFLOW_V_3_0_1: from airflow.exceptions import DownstreamTasksSkipped @@ -1827,9 +1831,9 @@ def f(): branch_2_ti.run() assert branch_2_ti.state == TaskInstanceState.SKIPPED else: - branch_op.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti(branch_op.task_id, dr) for task in branches: - task.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti(task.task_id, dr) expected_states = { self.task_id: State.SUCCESS, @@ -1850,7 +1854,7 @@ def f(): # Run the cleared tasks again. for task in branches: - task.run(start_date=self.default_date, end_date=self.default_date) + self.dag_maker.run_ti(task.task_id, dr) # Check if the states are correct after children tasks are cleared. 
self.assert_expected_task_states(dr, expected_states) @@ -2000,35 +2004,27 @@ def clear_db(): clear_db_runs() -DEFAULT_ARGS = { - "owner": "test", - "depends_on_past": True, - "start_date": datetime(2022, 1, 1), - "end_date": datetime.today(), - "retries": 1, - "retry_delay": timedelta(minutes=1), -} - - @pytest.mark.usefixtures("clear_db") class TestCurrentContextRuntime: - def test_context_in_task(self): - with DAG(dag_id="assert_context_dag", default_args=DEFAULT_ARGS, schedule="@once"): + def test_context_in_task(self, dag_maker): + with dag_maker(dag_id="assert_context_dag", serialized=True): op = MyContextAssertOperator(task_id="assert_context") - if AIRFLOW_V_3_0_1: - with pytest.warns(AirflowProviderDeprecationWarning): - op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) - else: - op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) + dr = dag_maker.create_dagrun() + if AIRFLOW_V_3_0_PLUS: + with pytest.warns(AirflowProviderDeprecationWarning): + dag_maker.run_ti(op.task_id, dr) + else: + dag_maker.run_ti(op.task_id, dr) - def test_get_context_in_old_style_context_task(self): - with DAG(dag_id="edge_case_context_dag", default_args=DEFAULT_ARGS, schedule="@once"): + def test_get_context_in_old_style_context_task(self, dag_maker): + with dag_maker(dag_id="assert_context_dag", serialized=True): op = PythonOperator(python_callable=get_all_the_context, task_id="get_all_the_context") - if AIRFLOW_V_3_0_1: - with pytest.warns(AirflowProviderDeprecationWarning): - op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) - else: - op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) + dr = dag_maker.create_dagrun() + if AIRFLOW_V_3_0_PLUS: + with pytest.warns(AirflowProviderDeprecationWarning): + dag_maker.run_ti(op.task_id, dr) + else: + dag_maker.run_ti(op.task_id, dr) @pytest.mark.need_serialized_dag(False) diff --git a/providers/standard/tests/unit/standard/operators/test_weekday.py b/providers/standard/tests/unit/standard/operators/test_weekday.py index 0372669c96179..bd49e45b1a701 100644 --- a/providers/standard/tests/unit/standard/operators/test_weekday.py +++ b/providers/standard/tests/unit/standard/operators/test_weekday.py @@ -120,11 +120,11 @@ def test_branch_follow_true(self, weekday, dag_maker): from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) assert exc_info.value.tasks == [("branch_3", -1)] else: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) self._assert_task_ids_match_states( dr, @@ -166,10 +166,10 @@ def test_branch_follow_true_with_logical_date(self, dag_maker): from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) assert exc_info.value.tasks == [("branch_2", -1)] else: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) self._assert_task_ids_match_states( dr, @@ -235,11 +235,11 @@ def test_branch_follow_false(self, dag_maker): from airflow.exceptions import DownstreamTasksSkipped with pytest.raises(DownstreamTasksSkipped) as exc_info: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) assert exc_info.value.tasks == [("branch_1", -1)] else: - 
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) self._assert_task_ids_match_states( dr, @@ -345,7 +345,7 @@ def test_branch_xcom_push_true_branch(self, dag_maker): assert exc_info.value.tasks == [("branch_2", -1)] else: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dag_maker.run_ti("make_choice", dr) assert branch_op_ti.xcom_pull(task_ids="make_choice", key=XCOM_SKIPMIXIN_KEY) == { XCOM_SKIPMIXIN_FOLLOWED: ["branch_1"] diff --git a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py index 78d2b9434663c..fc1f982c8e227 100644 --- a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py +++ b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py @@ -33,7 +33,12 @@ TaskDeferred, ) from airflow.models import DagBag, DagRun, TaskInstance -from airflow.models.baseoperator import BaseOperator +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator +else: + from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel from airflow.models.xcom_arg import XComArg diff --git a/task-sdk/src/airflow/sdk/bases/decorator.py b/task-sdk/src/airflow/sdk/bases/decorator.py index 6780fe29fe27d..c1063774220de 100644 --- a/task-sdk/src/airflow/sdk/bases/decorator.py +++ b/task-sdk/src/airflow/sdk/bases/decorator.py @@ -249,10 +249,9 @@ def execute(self, context: Context): if isinstance(arg, Asset): self.inlets.append(arg) return_value = super().execute(context) - # TODO(potiuk) - this xcom push is temporary and should be fixed - return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push) # type: ignore[attr-defined] + return self._handle_output(return_value=return_value) - def _handle_output(self, return_value: Any, context: Context, xcom_push: Callable): + def _handle_output(self, return_value: Any): """ Handle logic for whether a decorator needs to push a single return value or multiple return values. From 0442579d5f231215b5e4e14f2fb171ed9213ccb2 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 27 Jun 2025 14:54:42 +0530 Subject: [PATCH 2/3] fixup! Replace `models.BaseOperator` to Task SDK one for Standard Provider --- devel-common/src/tests_common/pytest_plugin.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 4a6f86f7ac66c..ba3e209901938 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -1120,6 +1120,8 @@ def run_ti(self, task_id, dag_run=None, dag_run_kwargs=None, **kwargs): Returns the created TaskInstance. """ + from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS + if dag_run is None: if dag_run_kwargs is None: dag_run_kwargs = {} @@ -1131,7 +1133,21 @@ def run_ti(self, task_id, dag_run=None, dag_run_kwargs=None, **kwargs): f"Task instance with task_id '{task_id}' not found in dag run. 
" f"Available task_ids: {available_task_ids}" ) - ti.refresh_from_task(self.dag.get_task(ti.task_id)) + task = self.dag.get_task(ti.task_id) + + if not AIRFLOW_V_3_1_PLUS: + # Airflow <3.1 has a bug for DecoratedOperator has an unused signature for + # `DecoratedOperator._handle_output` for xcom_push + # This worked for `models.BaseOperator` since it had xcom_push method but for + # `airflow.sdk.BaseOperator`, this does not exist, so this returns an AttributeError + # Since this is an unused attribute anyway, we just monkey patch it with a lambda. + # Error otherwise: + # /usr/local/lib/python3.11/site-packages/airflow/sdk/bases/decorator.py:253: in execute + # return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push) + # ^^^^^^^^^^^^^^ + # E AttributeError: '_PythonDecoratedOperator' object has no attribute 'xcom_push' + task.xcom_push = lambda *args, **kwargs: None + ti.refresh_from_task(task) ti.run(**kwargs) return ti From 27642e267923fbce6c54eb15d408391a003bdb20 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 27 Jun 2025 16:58:36 +0530 Subject: [PATCH 3/3] fixup! fixup! Replace `models.BaseOperator` to Task SDK one for Standard Provider --- .../src/airflow/providers/standard/operators/bash.py | 9 ++++++--- .../src/airflow/providers/standard/operators/branch.py | 9 ++++++--- .../src/airflow/providers/standard/operators/empty.py | 4 ++-- .../src/airflow/providers/standard/operators/python.py | 9 ++++++--- .../src/airflow/providers/standard/utils/skipmixin.py | 6 +++--- .../src/airflow/providers/standard/version_compat.py | 1 + 6 files changed, 24 insertions(+), 14 deletions(-) diff --git a/providers/standard/src/airflow/providers/standard/operators/bash.py b/providers/standard/src/airflow/providers/standard/operators/bash.py index 04a2bfedf3be3..3195b33456c59 100644 --- a/providers/standard/src/airflow/providers/standard/operators/bash.py +++ b/providers/standard/src/airflow/providers/standard/operators/bash.py @@ -26,13 +26,16 @@ from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.standard.hooks.subprocess import SubprocessHook, SubprocessResult, working_directory -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS -if AIRFLOW_V_3_0_PLUS: +if AIRFLOW_V_3_1_PLUS: from airflow.sdk import BaseOperator - from airflow.sdk.execution_time.context import context_to_airflow_vars else: from airflow.models.baseoperator import BaseOperator + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.context import context_to_airflow_vars +else: from airflow.utils.operator_helpers import context_to_airflow_vars # type: ignore[no-redef, attr-defined] if TYPE_CHECKING: diff --git a/providers/standard/src/airflow/providers/standard/operators/branch.py b/providers/standard/src/airflow/providers/standard/operators/branch.py index 3db925ace82aa..44c958c4f2eea 100644 --- a/providers/standard/src/airflow/providers/standard/operators/branch.py +++ b/providers/standard/src/airflow/providers/standard/operators/branch.py @@ -22,13 +22,16 @@ from collections.abc import Iterable from typing import TYPE_CHECKING -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS -if AIRFLOW_V_3_0_PLUS: - from airflow.providers.standard.utils.skipmixin import SkipMixin +if AIRFLOW_V_3_1_PLUS: from airflow.sdk import 
BaseOperator else: from airflow.models.baseoperator import BaseOperator + +if AIRFLOW_V_3_0_PLUS: + from airflow.providers.standard.utils.skipmixin import SkipMixin +else: from airflow.models.skipmixin import SkipMixin if TYPE_CHECKING: diff --git a/providers/standard/src/airflow/providers/standard/operators/empty.py b/providers/standard/src/airflow/providers/standard/operators/empty.py index 3d91a574e9221..4cbd6254d4a9e 100644 --- a/providers/standard/src/airflow/providers/standard/operators/empty.py +++ b/providers/standard/src/airflow/providers/standard/operators/empty.py @@ -18,9 +18,9 @@ from typing import TYPE_CHECKING -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS -if AIRFLOW_V_3_0_PLUS: +if AIRFLOW_V_3_1_PLUS: from airflow.sdk import BaseOperator else: from airflow.models.baseoperator import BaseOperator diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py index b8846259ab2fc..54ae1bc3200a6 100644 --- a/providers/standard/src/airflow/providers/standard/operators/python.py +++ b/providers/standard/src/airflow/providers/standard/operators/python.py @@ -50,19 +50,22 @@ ) from airflow.models.variable import Variable from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv, write_python_script -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS from airflow.utils import hashlib_wrapper from airflow.utils.context import context_copy_partial, context_merge from airflow.utils.file import get_unique_dag_module_name from airflow.utils.operator_helpers import KeywordParameters from airflow.utils.process_utils import execute_in_subprocess +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import BaseOperator +else: + from airflow.models.baseoperator import BaseOperator + if AIRFLOW_V_3_0_PLUS: from airflow.providers.standard.operators.branch import BaseBranchOperator from airflow.providers.standard.utils.skipmixin import SkipMixin - from airflow.sdk import BaseOperator else: - from airflow.models.baseoperator import BaseOperator from airflow.models.skipmixin import SkipMixin from airflow.operators.branch import BaseBranchOperator # type: ignore[no-redef] diff --git a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py index 4351c6c8248af..5cfb577cd2f7c 100644 --- a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py +++ b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py @@ -22,14 +22,14 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.types import RuntimeTaskInstanceProtocol - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_1_PLUS: from airflow.sdk.types import Operator else: from airflow.models.operator import Operator @@ -45,7 +45,7 @@ def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]: - if AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_1_PLUS: from airflow.sdk import BaseOperator from 
airflow.sdk.definitions.mappedoperator import MappedOperator else: diff --git a/providers/standard/src/airflow/providers/standard/version_compat.py b/providers/standard/src/airflow/providers/standard/version_compat.py index 48d122b669696..42bdcc7da03df 100644 --- a/providers/standard/src/airflow/providers/standard/version_compat.py +++ b/providers/standard/src/airflow/providers/standard/version_compat.py @@ -33,3 +33,4 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) +AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
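
Reviewer note: a minimal sketch of how a converted provider test drives a
single task through the new `dag_maker.run_ti` helper from this series. The
DAG id and task id below are illustrative, and the sketch assumes the
`dag_maker` fixture from `tests_common.pytest_plugin` is on the test path.

    import pytest

    from airflow.providers.standard.operators.bash import BashOperator
    from airflow.utils.state import State


    @pytest.mark.db_test
    def test_run_ti_sketch(dag_maker):
        # Build the DAG through the fixture so the serialized representation
        # is exercised, matching the converted tests in this series.
        with dag_maker("run_ti_example", serialized=True):
            BashOperator(task_id="echo", bash_command="echo hello")

        # Create the run first; run_ti then looks up the task instance by
        # task_id, refreshes it from the task, and calls ti.run(**kwargs).
        dr = dag_maker.create_dagrun()
        ti = dag_maker.run_ti("echo", dr)

        # run_ti returns the TaskInstance it executed.
        assert ti.state == State.SUCCESS

Compared to the removed `task.run(start_date=..., end_date=...)` calls, the
DagRun stays explicit and the test no longer relies on `run()` existing on
the operator, which `airflow.sdk.BaseOperator` does not provide.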