Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from airflow.models.taskinstance import PAST_DEPENDS_MET
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule as TR

if TYPE_CHECKING:
Expand Down Expand Up @@ -63,8 +64,7 @@ def calculate(cls, finished_upstreams: Iterator[TaskInstance]) -> _UpstreamTISta
``counter`` is inclusive of ``setup_counter`` -- e.g. if there are 2 skipped upstreams, one
of which is a setup, then counter will show 2 skipped and setup counter will show 1.

:param ti: the ti that we want to calculate deps for
:param finished_tis: all the finished tasks of the dag_run
:param finished_upstreams: all the finished upstreams of the dag_run
"""
counter: dict[str, int] = Counter()
setup_counter: dict[str, int] = Counter()
Expand Down Expand Up @@ -143,6 +143,19 @@ def _get_expanded_ti_count() -> int:

return ti.task.get_mapped_ti_count(ti.run_id, session=session)

def _iter_expansion_dependencies(task_group: MappedTaskGroup) -> Iterator[str]:
from airflow.models.mappedoperator import MappedOperator

if isinstance(ti.task, MappedOperator):
for op in ti.task.iter_mapped_dependencies():
yield op.task_id
if task_group and task_group.iter_mapped_task_groups():
yield from (
op.task_id
for tg in task_group.iter_mapped_task_groups()
for op in tg.iter_mapped_dependencies()
)

@functools.lru_cache
def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
"""
Expand All @@ -156,6 +169,13 @@ def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
assert ti.task
assert isinstance(ti.task.dag, DAG)

if isinstance(ti.task.task_group, MappedTaskGroup):
is_fast_triggered = ti.task.trigger_rule in (TR.ONE_SUCCESS, TR.ONE_FAILED, TR.ONE_DONE)
if is_fast_triggered and upstream_id not in set(
_iter_expansion_dependencies(task_group=ti.task.task_group)
):
return None

try:
expanded_ti_count = _get_expanded_ti_count()
except (NotFullyPopulated, NotMapped):
Expand Down Expand Up @@ -217,7 +237,7 @@ def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnOperators]
for upstream_id in relevant_tasks:
map_indexes = _get_relevant_upstream_map_indexes(upstream_id)
if map_indexes is None: # All tis of this upstream are dependencies.
yield (TaskInstance.task_id == upstream_id)
yield TaskInstance.task_id == upstream_id
continue
# At this point we know we want to depend on only selected tis
# of this upstream task. Since the upstream may not have been
Expand All @@ -237,11 +257,9 @@ def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnOperators]

def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus, bool]]:
"""
Evaluate whether ``ti``'s trigger rule was met.
Evaluate whether ``ti``'s trigger rule was met as part of the setup constraint.

:param ti: Task instance to evaluate the trigger rule of.
:param dep_context: The current dependency context.
:param session: Database session.
:param relevant_setups: Relevant setups for the current task instance.
"""
if TYPE_CHECKING:
assert ti.task
Expand Down Expand Up @@ -327,13 +345,7 @@ def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus
)

def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
"""
Evaluate whether ``ti``'s trigger rule was met.

:param ti: Task instance to evaluate the trigger rule of.
:param dep_context: The current dependency context.
:param session: Database session.
"""
"""Evaluate whether ``ti``'s trigger rule in direct relatives was met."""
if TYPE_CHECKING:
assert ti.task

Expand Down Expand Up @@ -433,7 +445,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
)
if not past_depends_met:
yield self._failing_status(
reason=("Task should be skipped but the past depends are not met")
reason="Task should be skipped but the past depends are not met"
)
return
changed = ti.set_state(new_state, session)
Expand Down
1 change: 1 addition & 0 deletions newsfragments/44937.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix pre-mature evaluation of tasks in mapped task group. The origins of the bug are in ``TriggerRuleDep``, when dealing with ``TriggerRule`` that is fastly triggered (i.e, ``ONE_FAILED``, ``ONE_SUCCESS`, or ``ONE_DONE``). Please note that at time of merging, this fix has been applied only for Airflow version > 2.10.4 and < 3, and should be ported to v3 after merging PR #40460.
86 changes: 85 additions & 1 deletion tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from airflow.models.taskmap import TaskMap
from airflow.models.xcom_arg import XComArg
from airflow.operators.python import PythonOperator
from airflow.utils.state import TaskInstanceState
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.task_instance_session import set_current_task_instance_session
from airflow.utils.trigger_rule import TriggerRule
Expand Down Expand Up @@ -1784,3 +1784,87 @@ def group(n: int) -> None:
"group.last": {0: "success", 1: "skipped", 2: "success"},
}
assert states == expected


def test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete(dag_maker, session):
"""Test that one failed trigger rule works well in mapped task group"""
with dag_maker() as dag:

@dag.task
def t1():
return [1, 2, 3]

@task_group("tg1")
def tg1(a):
@dag.task()
def t2(a):
return a

@dag.task(trigger_rule=TriggerRule.ONE_FAILED)
def t3(a):
return a

t2(a) >> t3(a)

t = t1()
tg1.expand(a=t)

dr = dag_maker.create_dagrun()
ti = dr.get_task_instance(task_id="t1")
ti.run()
dr.task_instance_scheduling_decisions()
ti3 = dr.get_task_instance(task_id="tg1.t3")
assert not ti3.state


def test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete__mapped_skip_with_all_success(
dag_maker, session
):
with dag_maker():

@task
def make_list():
return [4, 42, 2]

@task
def double(n):
if n == 42:
raise AirflowSkipException("42")
return n * 2

@task
def last(n):
print(n)

@task_group
def group(n: int) -> None:
last(double(n))

list = make_list()
group.expand(n=list)

dr = dag_maker.create_dagrun()

def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]:
decision = dr.task_instance_scheduling_decisions(session=session)
return {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}

tis = _one_scheduling_decision_iteration()
tis["make_list", -1].run()
assert tis["make_list", -1].state == State.SUCCESS

tis = _one_scheduling_decision_iteration()
tis["group.double", 0].run()
tis["group.double", 1].run()
tis["group.double", 2].run()

assert tis["group.double", 0].state == State.SUCCESS
assert tis["group.double", 1].state == State.SKIPPED
assert tis["group.double", 2].state == State.SUCCESS

tis = _one_scheduling_decision_iteration()
tis["group.last", 0].run()
tis["group.last", 2].run()
assert tis["group.last", 0].state == State.SUCCESS
assert dr.get_task_instance("group.last", map_index=1, session=session).state == State.SKIPPED
assert tis["group.last", 2].state == State.SUCCESS