From 68013b8986caf12908b9176c0cce175b0d48d9e0 Mon Sep 17 00:00:00 2001 From: Kirilishin Aleksei Date: Wed, 30 Aug 2023 13:31:15 +0300 Subject: [PATCH] Fix issue #33164: count removed upstream tasks for non-mapped tasks. --- airflow/ti_deps/deps/trigger_rule_dep.py | 8 +-- tests/ti_deps/deps/test_trigger_rule_dep.py | 69 ++++++++++++++++++++- 2 files changed, 70 insertions(+), 7 deletions(-) diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index 9731d2a7c235e..8f4465afc55c3 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -332,9 +332,7 @@ def _iter_upstream_conditions() -> Iterator[ColumnOperators]: ) ) elif trigger_rule == TR.ALL_SUCCESS: - num_failures = upstream - success - if ti.map_index > -1: - num_failures -= removed + num_failures = upstream - success - removed if num_failures > 0: yield self._failing_status( reason=( @@ -345,9 +343,7 @@ def _iter_upstream_conditions() -> Iterator[ColumnOperators]: ) ) elif trigger_rule == TR.ALL_FAILED: - num_success = upstream - failed - upstream_failed - if ti.map_index > -1: - num_success -= removed + num_success = upstream - failed - upstream_failed - removed if num_success > 0: yield self._failing_status( reason=( diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index faa70b5a4951d..ca5a02e286851 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -101,6 +101,7 @@ def do_something_else(i): with dag_maker(dag_id="test_dag"): nums = do_something.expand(i=[i + 1 for i in range(5)]) + nums >> EmptyOperator(task_id="do_something_non_mapped_task", trigger_rule=trigger_rule) do_something_else.expand(i=nums) dr = dag_maker.create_dagrun() @@ -112,7 +113,7 @@ def do_something_else(i): ti.dag_run = dr session.add(ti) session.flush() - tis = dr.get_task_instances() + tis = dr.get_task_instances()[:-1] for ti in tis: if ti.task_id == "do_something": if ti.map_index > 2: @@ -1029,6 +1030,39 @@ def test_mapped_task_upstream_removed_with_all_success_trigger_rules( assert len(dep_statuses) == 0 assert ti.state == TaskInstanceState.REMOVED + def test_nonmapped_task_upstream_removed_with_all_success_trigger_rules( + self, + monkeypatch, + session, + get_mapped_task_dagrun, + ): + """ + Test ALL_SUCCESS trigger rule with non-mapped task upstream removed + """ + upstream_states = _UpstreamTIStates( + success=3, + skipped=0, + failed=0, + removed=2, + upstream_failed=0, + done=5, + skipped_setup=0, + success_setup=0, + ) + monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states) + + dr, _ = get_mapped_task_dagrun() + ti = dr.get_task_instance(task_id="do_something_non_mapped_task", session=session) + ti.task = dr.dag.task_dict["do_something_non_mapped_task"] + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + dep_context=DepContext(flag_upstream_failed=True), + session=session, + ) + ) + assert len(dep_statuses) == 0 + def test_mapped_task_upstream_removed_with_all_failed_trigger_rules( self, monkeypatch, @@ -1067,6 +1101,39 @@ def test_mapped_task_upstream_removed_with_all_failed_trigger_rules( assert len(dep_statuses) == 0 + def test_nonmapped_task_upstream_removed_with_all_failed_trigger_rules( + self, + monkeypatch, + session, + get_mapped_task_dagrun, + ): + """ + Test ALL_FAILED trigger rule with non-mapped task upstream removed + """ + upstream_states = _UpstreamTIStates( + success=0, + skipped=0, + failed=3, + removed=2, + upstream_failed=0, + done=5, + skipped_setup=0, + success_setup=0, + ) + monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states) + + dr, _ = get_mapped_task_dagrun(trigger_rule=TriggerRule.ALL_FAILED, state=TaskInstanceState.FAILED) + ti = dr.get_task_instance(task_id="do_something_non_mapped_task", session=session) + ti.task = dr.dag.task_dict["do_something_non_mapped_task"] + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + dep_context=DepContext(flag_upstream_failed=False), + session=session, + ) + ) + assert len(dep_statuses) == 0 + @pytest.mark.parametrize( "trigger_rule", [TriggerRule.NONE_FAILED, TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS],