24 changes: 12 additions & 12 deletions airflow-core/tests/unit/core/test_core.py
@@ -24,9 +24,7 @@
import pytest

from airflow.exceptions import AirflowTaskTimeout
from airflow.models import TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.utils.timezone import datetime
@@ -49,8 +47,15 @@ def teardown_method(self):
self.clean_db()

def test_dryrun(self, dag_maker):
class TemplateFieldOperator(BaseOperator):
template_fields = ["bash_command"]

def __init__(self, bash_command, **kwargs):
self.bash_command = bash_command
super().__init__(**kwargs)

with dag_maker():
op = BashOperator(task_id="test_dryrun", bash_command="echo success")
op = TemplateFieldOperator(task_id="test_dryrun", bash_command="echo success")
dag_maker.create_dagrun()
op.dry_run()

@@ -81,17 +86,15 @@ def sleep_and_catch_other_exceptions():
execution_timeout=timedelta(seconds=1),
python_callable=sleep_and_catch_other_exceptions,
)
dag_maker.create_dagrun()
with pytest.raises(AirflowTaskTimeout):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
dag_maker.run_ti(op.task_id)

def test_dag_params_and_task_params(self, dag_maker):
# This test case guards how DAG params and task params work together.
# - If any key exists in either DAG's or Operator's params,
# it is guaranteed to be available eventually.
# - If any key exists in both DAG's params and Operator's params,
# the latter has precedence.
TI = TaskInstance

with dag_maker(
schedule=timedelta(weeks=1),
@@ -106,12 +109,9 @@ def test_dag_params_and_task_params(self, dag_maker):
dr = dag_maker.create_dagrun(
run_type=DagRunType.SCHEDULED,
)
task1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
task2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
ti1 = TI(task=task1, run_id=dr.run_id)
ti2 = TI(task=task2, run_id=dr.run_id)
ti1.refresh_from_db()
ti2.refresh_from_db()
ti1 = dag_maker.run_ti(task1.task_id, dr)
ti2 = dag_maker.run_ti(task2.task_id, dr)

context1 = ti1.get_template_context()
context2 = ti2.get_template_context()

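Note on the dry-run change above: dry_run() renders the operator's template_fields without calling execute(), so a provider-free local operator is enough. A minimal sketch of what the test exercises, reusing the names from the diff (my reading, not part of the PR):

op = TemplateFieldOperator(task_id="test_dryrun", bash_command="echo success")
# dry_run() renders each attribute named in template_fields and logs the
# rendered value; execute() is never invoked, so no provider code is needed.
op.dry_run()
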
9 changes: 7 additions & 2 deletions airflow-core/tests/unit/models/test_cleartasks.py
@@ -23,6 +23,7 @@
import pytest
from sqlalchemy import select

from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance, TaskInstance as TI, clear_task_instances
@@ -677,9 +678,13 @@ def _get_ti(old_ti):
assert ti.max_tries == 1

def test_operator_clear(self, dag_maker, session):
class ClearOperator(BaseOperator):
def execute(self, context):
pass

with dag_maker("test_operator_clear"):
op1 = EmptyOperator(task_id="test1")
op2 = EmptyOperator(task_id="test2", retries=1)
op1 = ClearOperator(task_id="test1")
op2 = ClearOperator(task_id="test2", retries=1)
op1 >> op2

dr = dag_maker.create_dagrun(
68 changes: 37 additions & 31 deletions airflow-core/tests/unit/models/test_taskinstance.py
@@ -41,6 +41,7 @@
AirflowSkipException,
)
from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel
from airflow.models.baseoperator import BaseOperator
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
@@ -775,8 +776,11 @@ def run_ti_and_assert(
assert not task_reschedules_for_ti(ti)

def test_depends_on_past_catchup_true(self, dag_maker):
class CustomOp(BaseOperator):
def execute(self, context): ...

with dag_maker(dag_id="test_depends_on_past", serialized=True, catchup=True):
task = EmptyOperator(
task = CustomOp(
task_id="test_dop_task",
depends_on_past=True,
)
@@ -806,8 +810,11 @@ def test_depends_on_past_catchup_true(self, dag_maker):
assert ti.state == State.SUCCESS

def test_depends_on_past_catchup_false(self, dag_maker):
class CustomOp(BaseOperator):
def execute(self, context): ...

with dag_maker(dag_id="test_depends_on_past_catchup_false", serialized=True, catchup=False):
task = EmptyOperator(
task = CustomOp(
task_id="test_dop_task",
depends_on_past=True,
)
@@ -2281,7 +2288,10 @@ def test_template_with_json_variable_missing(self, create_task_instance, session
ti.task.render_template('{{ var.json.get("missing_variable") }}', context)

@provide_session
def test_handle_failure(self, create_dummy_dag, session=None):
def test_handle_failure(self, dag_maker, session=None):
class CustomOp(BaseOperator):
def execute(self, context): ...

start_date = timezone.datetime(2016, 6, 1)
clear_db_runs()

@@ -2298,16 +2308,15 @@ def test_handle_failure(self, create_dummy_dag, session=None):
__name__="mock_on_retry_1",
__call__=mock.MagicMock(),
)
dag, task1 = create_dummy_dag(
dag_id="test_handle_failure",
schedule=None,
start_date=start_date,
task_id="test_handle_failure_on_failure",
with_dagrun_type=DagRunType.MANUAL,
on_failure_callback=mock_on_failure_1,
on_retry_callback=mock_on_retry_1,
session=session,
)
with dag_maker(dag_id="test_handle_failure", start_date=start_date, schedule=None) as dag:
task1 = CustomOp(
task_id="test_handle_failure_on_failure",
on_failure_callback=mock_on_failure_1,
on_retry_callback=mock_on_retry_1,
)

dag_maker.create_dagrun(session=session, run_type=DagRunType.MANUAL, start_date=start_date)

logical_date = timezone.utcnow()
dr = dag.create_dagrun(
run_id="test2",
@@ -2347,7 +2356,7 @@ def test_handle_failure(self, create_dummy_dag, session=None):
__name__="mock_on_retry_2",
__call__=mock.MagicMock(),
)
task2 = EmptyOperator(
task2 = CustomOp(
task_id="test_handle_failure_on_retry",
on_failure_callback=mock_on_failure_2,
on_retry_callback=mock_on_retry_2,
@@ -2375,7 +2384,7 @@ def test_handle_failure(self, create_dummy_dag, session=None):
__name__="mock_on_retry_3",
__call__=mock.MagicMock(),
)
task3 = EmptyOperator(
task3 = CustomOp(
task_id="test_handle_failure_on_force_fail",
on_failure_callback=mock_on_failure_3,
on_retry_callback=mock_on_retry_3,
@@ -2450,20 +2459,23 @@ def test_handle_failure_task_undefined(self, create_task_instance):
del ti.task
ti.handle_failure("test ti.task undefined")

def test_handle_failure_fail_fast(self, create_dummy_dag, session):
def test_handle_failure_fail_fast(self, dag_maker, session):
start_date = timezone.datetime(2016, 6, 1)
clear_db_runs()

dag, task1 = create_dummy_dag(
class CustomOp(BaseOperator):
def execute(self, context): ...

with dag_maker(
dag_id="test_handle_failure_fail_fast",
schedule=None,
start_date=start_date,
task_id="task1",
trigger_rule="all_success",
with_dagrun_type=DagRunType.MANUAL,
session=session,
schedule=None,
fail_fast=True,
)
) as dag:
task1 = CustomOp(task_id="task1", trigger_rule="all_success")

dag_maker.create_dagrun(run_type=DagRunType.MANUAL, start_date=start_date)

logical_date = timezone.utcnow()
dr = dag.create_dagrun(
run_id="test_ff",
@@ -2484,19 +2496,13 @@ def test_handle_failure_fail_fast(self, create_dummy_dag, session):
states = [State.RUNNING, State.FAILED, State.QUEUED, State.SCHEDULED, State.DEFERRED]
tasks = []
for i, state in enumerate(states):
op = EmptyOperator(
task_id=f"reg_Task{i}",
dag=dag,
)
op = CustomOp(task_id=f"reg_Task{i}", dag=dag)
ti = TI(task=op, run_id=dr.run_id)
ti.state = state
session.add(ti)
tasks.append(ti)

fail_task = EmptyOperator(
task_id="fail_Task",
dag=dag,
)
fail_task = CustomOp(task_id="fail_Task", dag=dag)
ti_ff = TI(task=fail_task, run_id=dr.run_id)
ti_ff.state = State.FAILED
session.add(ti_ff)
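Note on the recurring swap in the hunks above — EmptyOperator out, a locally defined operator in: as I read it, the point is that core tests no longer import airflow.providers.standard just for a no-op task. The whole pattern is:

# No-op operator defined inline in the test module, so core tests do not
# depend on the standard provider package for EmptyOperator (sketch of the
# pattern used throughout this PR).
from airflow.models.baseoperator import BaseOperator

class CustomOp(BaseOperator):
    def execute(self, context): ...
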
51 changes: 51 additions & 0 deletions devel-common/src/tests_common/pytest_plugin.py
@@ -806,6 +806,14 @@ def create_dagrun(

def create_dagrun_after(self, dagrun: DagRun, **kwargs) -> DagRun: ...

def run_ti(
self,
task_id: str,
dag_run: DagRun | None = ...,
dag_run_kwargs: dict | None = ...,
**kwargs,
) -> TaskInstance: ...

def __call__(
self,
dag_id: str = "test_dag",
@@ -1100,6 +1108,49 @@ def create_dagrun_after(self, dagrun, **kwargs):
**kwargs,
)

def run_ti(self, task_id, dag_run=None, dag_run_kwargs=None, **kwargs):
"""
Create a dagrun and run a specific task instance with proper task refresh.

This is a convenience method for running a single task instance:
1. Create a dagrun if one is not passed in
2. Get the specific task instance by task_id
3. Refresh the task instance from the DAG task
4. Run the task instance

Returns the created TaskInstance.
"""
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS

if dag_run is None:
if dag_run_kwargs is None:
dag_run_kwargs = {}
dag_run = self.create_dagrun(**dag_run_kwargs)
ti = dag_run.get_task_instance(task_id=task_id)
if ti is None:
available_task_ids = [task.task_id for task in self.dag.tasks]
raise ValueError(
f"Task instance with task_id '{task_id}' not found in dag run. "
f"Available task_ids: {available_task_ids}"
)
task = self.dag.get_task(ti.task_id)

if not AIRFLOW_V_3_1_PLUS:
# Airflow <3.1 has a bug where `DecoratedOperator._handle_output` takes an
# unused `xcom_push` argument in its signature.
# This worked for `models.BaseOperator`, which had an `xcom_push` method, but
# `airflow.sdk.BaseOperator` does not, so the lookup raises an AttributeError.
# Since the attribute is unused anyway, we just monkey-patch it with a lambda.
# Error otherwise:
# /usr/local/lib/python3.11/site-packages/airflow/sdk/bases/decorator.py:253: in execute
# return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push)
# ^^^^^^^^^^^^^^
# E AttributeError: '_PythonDecoratedOperator' object has no attribute 'xcom_push'
task.xcom_push = lambda *args, **kwargs: None
ti.refresh_from_task(task)
ti.run(**kwargs)
return ti

def sync_dagbag_to_db(self):
if not AIRFLOW_V_3_0_PLUS:
self.dagbag.sync_to_db()
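A hedged sketch of how a test would use the new dag_maker.run_ti helper (the test name, operator, and assertion here are illustrative, not from this PR):

from airflow.models.baseoperator import BaseOperator
from airflow.utils.state import State


class HelloOp(BaseOperator):
    def execute(self, context):
        return "hi"


def test_hello(dag_maker):
    with dag_maker(dag_id="example_dag"):
        HelloOp(task_id="hello")
    # No dag_run is passed, so run_ti creates one via create_dagrun(),
    # refreshes the TaskInstance from the DAG task, runs it, and returns it.
    ti = dag_maker.run_ti("hello")
    assert ti.state == State.SUCCESS
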
7 changes: 1 addition & 6 deletions providers/amazon/tests/system/amazon/aws/utils/k8s.py
@@ -16,15 +16,10 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.providers.standard.operators.bash import BashOperator

if TYPE_CHECKING:
from airflow.models.operator import Operator


def get_describe_pod_operator(cluster_name: str, pod_name: str) -> Operator:
def get_describe_pod_operator(cluster_name: str, pod_name: str) -> BashOperator:
"""Returns an operator that'll print the output of a `k describe pod` in the airflow logs."""
return BashOperator(
task_id="describe_pod",
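With the annotation narrowed from Operator to the concrete BashOperator, callers can touch BashOperator-specific attributes without a cast; an illustrative use (not from this PR):

describe = get_describe_pod_operator(cluster_name="demo", pod_name="my-pod")
# Type checkers now resolve BashOperator attributes directly:
print(describe.bash_command)
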
@@ -103,11 +103,11 @@ def f(spark, sc):
return [random.random() for _ in range(100)]

with dag_maker():
ret = f()
f()

dr = dag_maker.create_dagrun()
ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date)
ti = dr.get_task_instances()[0]
ti.run()
assert len(ti.xcom_pull()) == 100
assert config.get("spark.master") == "spark://none"
assert config.get("spark.executor.memory") == "2g"
@@ -130,11 +130,11 @@ def f():
return e

with dag_maker():
ret = f()
f()

dr = dag_maker.create_dagrun()
ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date)
ti = dr.get_task_instances()[0]
ti.run()
assert ti.xcom_pull() == e
assert config.get("spark.master") == "local[*]"
spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
@@ -154,11 +154,11 @@ def f(spark, sc):
return True

with dag_maker():
ret = f()
f()

dr = dag_maker.create_dagrun()
ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date)
ti = dr.get_task_instances()[0]
ti.run()
assert ti.xcom_pull()
assert config.get("spark.remote") == "sc://localhost"
assert config.get("spark.master") is None
@@ -180,11 +180,11 @@ def f(spark, sc):
return True

with dag_maker():
ret = f()
f()

dr = dag_maker.create_dagrun()
ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date)
ti = dr.get_task_instances()[0]
ti.run()
assert ti.xcom_pull()
assert config.get("spark.remote") == "sc://localhost/;user_id=connect;token=1234;use_ssl=True"
assert config.get("spark.master") is None
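The recurring change in this last (unnamed) test file, condensed: instead of keeping the XComArg returned by the decorated callable and calling run() on its operator, the tests fetch the TaskInstance from the dagrun and run it directly. The shape of the pattern (a sketch, not new behavior):

with dag_maker():
    f()  # the decorated task registers itself on the current DAG

dr = dag_maker.create_dagrun()
ti = dr.get_task_instances()[0]  # single-task DAG, so index 0 is f's TI
ti.run()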