Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _get_states_count_upstream_ti(task, finished_tis):
counter.get(State.SKIPPED, 0),
counter.get(State.FAILED, 0),
counter.get(State.UPSTREAM_FAILED, 0),
counter.get(State.REMOVED, 0),
sum(counter.values()),
)

Expand All @@ -73,7 +74,7 @@ def _get_dep_statuses(self, ti, session, dep_context: DepContext):
yield self._passing_status(reason="The task had a always trigger rule set.")
return
# see if the task name is in the task upstream for our task
successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti(
successes, skipped, failed, upstream_failed, removed, done = self._get_states_count_upstream_ti(
task=ti.task, finished_tis=dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
)

Expand All @@ -83,6 +84,7 @@ def _get_dep_statuses(self, ti, session, dep_context: DepContext):
skipped=skipped,
failed=failed,
upstream_failed=upstream_failed,
removed=removed,
done=done,
flag_upstream_failed=dep_context.flag_upstream_failed,
dep_context=dep_context,
Expand Down Expand Up @@ -122,6 +124,7 @@ def _evaluate_trigger_rule(
skipped,
failed,
upstream_failed,
removed,
done,
flag_upstream_failed,
dep_context: DepContext,
Expand Down Expand Up @@ -152,6 +155,7 @@ def _evaluate_trigger_rule(
"successes": successes,
"skipped": skipped,
"failed": failed,
"removed": removed,
"upstream_failed": upstream_failed,
"done": done,
}
Expand All @@ -162,6 +166,9 @@ def _evaluate_trigger_rule(
changed = ti.set_state(State.UPSTREAM_FAILED, session)
elif skipped:
changed = ti.set_state(State.SKIPPED, session)
elif removed and successes and ti.map_index > -1:
if ti.map_index >= successes:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we compare map index with number of upstream successes? that seems odd?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, yes now you mention it this feels like it's going to break in some other cases.
Like what if there is 1 mapped upstream which is in the failed state, one in the removed state, this would erroneously remove it I think?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ephraimbuddy Could you take another look at this PR/case please?

Copy link
Copy Markdown
Contributor Author

@ephraimbuddy ephraimbuddy Sep 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, yes now you mention it this feels like it's going to break in some other cases. Like what if there is 1 mapped upstream which is in the failed state, one in the removed state, this would erroneously remove it I think?

In this case, successes will be 0, also failed=1, so the condition will not be reached and the taskinstance will be marked as upstream_failed. Same thing when we have skipped task instances. The condition to mark the task instance as removed will not be reached.

The condition for the task to be marked removed is if we have some removed task instances and successful task instances, no failed, no skipped and the task is mapped. So if we get here, if the map_index of the task instance is >= all successful task instances, it means the task instance upstream is removed because indexes go from -1 upwards, it's not possible to remove map_index 1 and still have map_index 3?

If we have 5 mapped tasks(0,1,2,3,4), and we remove 2, we will have 3 mapped tasks(0,1,2). If these 3 are successful,(successes=3), then the removed are those greater than or equal to the map index 3(3,4).

Copy link
Copy Markdown
Member

@ashb ashb Sep 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if a task has multiple upstreams?

[a, b] >> mapped_task(list_gen) for instance?

Edit: [a, b] >> mapped_task.map(list_gen) for instance?

And a is success, b is failure, and list_gen is reduced to only returning a single item?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would apply

if upstream_failed or failed:
changed = ti.set_state(State.UPSTREAM_FAILED, session)
it won't get to line 169.

Line 169 is only satisfied if we have removed, successes, no failed, no skipped and mapped task

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh good. This is important enough functionality (it's the very core of Airflow) that we should add test cases covering things like this

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does seem like it's covered here:

# Parameterized tests to check for the correct firing
# of the trigger_rule under various circumstances
# Numeric fields are in order:
# successes, skipped, failed, upstream_failed, done, removed
@pytest.mark.parametrize(
"trigger_rule,successes,skipped,failed,upstream_failed,done,removed,"
"flag_upstream_failed,expect_state,expect_completed",
[
#
# Tests for all_success
#
['all_success', 5, 0, 0, 0, 0, 0, True, None, True],
['all_success', 2, 0, 0, 0, 0, 0, True, None, False],
['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, False],
['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False],
#
# Tests for one_success
#
['one_success', 5, 0, 0, 0, 5, 0, True, None, True],
['one_success', 2, 0, 0, 0, 2, 0, True, None, True],
['one_success', 2, 0, 1, 0, 3, 0, True, None, True],
['one_success', 2, 1, 0, 0, 3, 0, True, None, True],
['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False],
['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, False],
#
# Tests for all_failed
#
['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False],
['all_failed', 0, 0, 5, 0, 5, 0, True, None, True],
['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False],
['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False],
#
# Tests for one_failed
#
['one_failed', 5, 0, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 1, 0, 0, 0, True, None, True],
['one_failed', 2, 1, 0, 0, 3, 0, True, None, False],
['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False],
#
# Tests for done
#
['all_done', 5, 0, 0, 0, 5, 0, True, None, True],
['all_done', 2, 0, 0, 0, 2, 0, True, None, False],
['all_done', 2, 0, 1, 0, 3, 0, True, None, False],
['all_done', 2, 1, 0, 0, 3, 0, True, None, False],
],
)
def test_check_task_dependencies(
self,
trigger_rule: str,
successes: int,
skipped: int,
failed: int,
removed: int,
upstream_failed: int,
done: int,
flag_upstream_failed: bool,
expect_state: State,
expect_completed: bool,
dag_maker,
):
with dag_maker() as dag:
downstream = EmptyOperator(task_id="downstream", trigger_rule=trigger_rule)
for i in range(5):
task = EmptyOperator(task_id=f'runme_{i}', dag=dag)
task.set_downstream(downstream)
assert task.start_date is not None
run_date = task.start_date + datetime.timedelta(days=5)
ti = dag_maker.create_dagrun(execution_date=run_date).get_task_instance(downstream.task_id)
ti.task = downstream
dep_results = TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=successes,
skipped=skipped,
failed=failed,
removed=removed,
upstream_failed=upstream_failed,
done=done,
dep_context=DepContext(),
flag_upstream_failed=flag_upstream_failed,
)
completed = all(dep.passed for dep in dep_results)
assert completed == expect_completed
assert ti.state == expect_state
# Parameterized tests to check for the correct firing
# of the trigger_rule under various circumstances of mapped task
# Numeric fields are in order:
# successes, skipped, failed, upstream_failed, done,removed
@pytest.mark.parametrize(
"trigger_rule,successes,skipped,failed,upstream_failed,done,removed,"
"flag_upstream_failed,expect_state,expect_completed",
[
#
# Tests for all_success
#
['all_success', 5, 0, 0, 0, 0, 0, True, None, True],
['all_success', 2, 0, 0, 0, 0, 0, True, None, False],
['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, False],
['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False],
['all_success', 3, 0, 0, 0, 0, 2, True, State.REMOVED, True], # ti.map_index >=successes
#
# Tests for one_success
#
['one_success', 5, 0, 0, 0, 5, 0, True, None, True],
['one_success', 2, 0, 0, 0, 2, 0, True, None, True],
['one_success', 2, 0, 1, 0, 3, 0, True, None, True],
['one_success', 2, 1, 0, 0, 3, 0, True, None, True],
['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False],
['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, False],
#
# Tests for all_failed
#
['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False],
['all_failed', 0, 0, 5, 0, 5, 0, True, None, True],
['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False],
['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 4, 1, True, State.SKIPPED, False], # One removed
#
# Tests for one_failed
#
['one_failed', 5, 0, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 1, 0, 0, 0, True, None, True],
['one_failed', 2, 1, 0, 0, 3, 0, True, None, False],
['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False],
['one_failed', 2, 2, 0, 0, 5, 1, True, State.SKIPPED, False], # One removed
#
# Tests for done
#
['all_done', 5, 0, 0, 0, 5, 0, True, None, True],
['all_done', 2, 0, 0, 0, 2, 0, True, None, False],
['all_done', 2, 0, 1, 0, 3, 0, True, None, False],
['all_done', 2, 1, 0, 0, 3, 0, True, None, False],
],
)
def test_check_task_dependencies_for_mapped(
, I'm taking a look too

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option is removing this part altogether. It's not part of the deadlock issue, but I feel that it's good-to-have functionality.
My reason is this:
If at first run upstream was 3 and downstream was 3 too. Upstream created the downstream. We have 3 -> 3 successes.
Then we reduce upstream to 2, meaning one task is removed and we clear and rerun the dag, without this part of the change, we will end up running all 3 of the downstreams: upstream (2 successful, 1 removed). Downstream(3 successful)

Copy link
Copy Markdown
Member

@ashb ashb Sep 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not this one specific line, they are all like this and that's the worry.

The test you highlighted doesn't use mapped tasks so I don't think it covers the case I highlighted. Edit: sorry, original example didn't have a map. Added that.

changed = ti.set_state(State.REMOVED, session)
elif trigger_rule == TR.ALL_FAILED:
if successes or skipped:
changed = ti.set_state(State.SKIPPED, session)
Expand Down Expand Up @@ -189,6 +196,7 @@ def _evaluate_trigger_rule(
elif trigger_rule == TR.ALL_SKIPPED:
if successes or failed:
changed = ti.set_state(State.SKIPPED, session)

if changed:
dep_context.have_changed_ti_states = True

Expand All @@ -212,6 +220,8 @@ def _evaluate_trigger_rule(
)
elif trigger_rule == TR.ALL_SUCCESS:
num_failures = upstream - successes
if ti.map_index > -1:
num_failures -= removed
if num_failures > 0:
yield self._failing_status(
reason=(
Expand All @@ -223,6 +233,8 @@ def _evaluate_trigger_rule(
)
elif trigger_rule == TR.ALL_FAILED:
num_successes = upstream - failed - upstream_failed
if ti.map_index > -1:
num_successes -= removed
if num_successes > 0:
yield self._failing_status(
reason=(
Expand All @@ -244,6 +256,8 @@ def _evaluate_trigger_rule(
)
elif trigger_rule == TR.NONE_FAILED:
num_failures = upstream - successes - skipped
if ti.map_index > -1:
num_failures -= removed
if num_failures > 0:
yield self._failing_status(
reason=(
Expand All @@ -255,6 +269,8 @@ def _evaluate_trigger_rule(
)
elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
num_failures = upstream - successes - skipped
if ti.map_index > -1:
num_failures -= removed
if num_failures > 0:
yield self._failing_status(
reason=(
Expand Down
40 changes: 40 additions & 0 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1905,3 +1905,43 @@ def say_hi():
dr.update_state(session=session)
assert dr.state == DagRunState.SUCCESS
assert tis['add_one__1'].state == TaskInstanceState.SKIPPED


def test_schedulable_task_exist_when_rerun_removed_upstream_mapped_task(session, dag_maker):
    """A downstream mapped task stays schedulable when some upstream mapped TIs are REMOVED.

    Builds a 5-wide mapped upstream feeding a mapped downstream, marks 3 upstream
    task instances SUCCESS and 2 REMOVED, then asserts that ``update_state`` still
    yields schedulable TIs and does not fail the DAG run.
    NOTE(review): relies on the ``session``/``dag_maker`` pytest fixtures and on
    ``TI``, ``TaskInstanceState``, ``DagRunState`` imported elsewhere in this test
    module — confirm against the full file.
    """
    from airflow.decorators import task

    @task
    def do_something(i):
        # Trivial mapped upstream task; return value is irrelevant to the test.
        return 1

    @task
    def do_something_else(i):
        # Trivial mapped downstream task expanded over the upstream's output.
        return 1

    with dag_maker():
        # 5 upstream mapped instances -> mapped downstream over their results.
        nums = do_something.expand(i=[i + 1 for i in range(5)])
        do_something_else.expand(i=nums)

    dr = dag_maker.create_dagrun()

    # The run starts with a single unexpanded downstream TI (map_index -1);
    # pin it to 0 and manually create TIs for map indexes 1..4 so the
    # downstream looks fully expanded to width 5.
    ti = dr.get_task_instance('do_something_else', session=session)
    ti.map_index = 0
    task = ti.task
    for map_index in range(1, 5):
        ti = TI(task, run_id=dr.run_id, map_index=map_index)
        ti.dag_run = dr
        session.add(ti)
    session.flush()
    tis = dr.get_task_instances()
    for ti in tis:
        if ti.task_id == 'do_something':
            # Simulate the upstream shrinking from 5 to 3: indexes 3 and 4
            # are REMOVED, indexes 0..2 are SUCCESS.
            if ti.map_index > 2:
                ti.state = TaskInstanceState.REMOVED
            else:
                ti.state = TaskInstanceState.SUCCESS
            session.merge(ti)
    session.commit()
    # The Upstream is done with 2 removed tis and 3 success tis
    (tis, _) = dr.update_state()
    # Downstream TIs must still be schedulable and the run must not deadlock
    # into FAILED because of the removed upstream instances.
    assert len(tis)
    assert dr.state != DagRunState.FAILED
178 changes: 147 additions & 31 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,55 +1065,55 @@ def test_depends_on_past(self, dag_maker):
# Parameterized tests to check for the correct firing
# of the trigger_rule under various circumstances
# Numeric fields are in order:
# successes, skipped, failed, upstream_failed, done
# successes, skipped, failed, upstream_failed, done, removed
@pytest.mark.parametrize(
"trigger_rule,successes,skipped,failed,upstream_failed,done,"
"trigger_rule,successes,skipped,failed,upstream_failed,done,removed,"
"flag_upstream_failed,expect_state,expect_completed",
[
#
# Tests for all_success
#
['all_success', 5, 0, 0, 0, 0, True, None, True],
['all_success', 2, 0, 0, 0, 0, True, None, False],
['all_success', 2, 0, 1, 0, 0, True, State.UPSTREAM_FAILED, False],
['all_success', 2, 1, 0, 0, 0, True, State.SKIPPED, False],
['all_success', 5, 0, 0, 0, 0, 0, True, None, True],
['all_success', 2, 0, 0, 0, 0, 0, True, None, False],
['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, False],
['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False],
#
# Tests for one_success
#
['one_success', 5, 0, 0, 0, 5, True, None, True],
['one_success', 2, 0, 0, 0, 2, True, None, True],
['one_success', 2, 0, 1, 0, 3, True, None, True],
['one_success', 2, 1, 0, 0, 3, True, None, True],
['one_success', 0, 5, 0, 0, 5, True, State.SKIPPED, False],
['one_success', 0, 4, 1, 0, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 3, 1, 1, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 4, 0, 1, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 5, 0, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 4, 1, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 0, 5, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 5, 0, 0, 0, 5, 0, True, None, True],
['one_success', 2, 0, 0, 0, 2, 0, True, None, True],
['one_success', 2, 0, 1, 0, 3, 0, True, None, True],
['one_success', 2, 1, 0, 0, 3, 0, True, None, True],
['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False],
['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, False],
#
# Tests for all_failed
#
['all_failed', 5, 0, 0, 0, 5, True, State.SKIPPED, False],
['all_failed', 0, 0, 5, 0, 5, True, None, True],
['all_failed', 2, 0, 0, 0, 2, True, State.SKIPPED, False],
['all_failed', 2, 0, 1, 0, 3, True, State.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 3, True, State.SKIPPED, False],
['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False],
['all_failed', 0, 0, 5, 0, 5, 0, True, None, True],
['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False],
['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False],
#
# Tests for one_failed
#
['one_failed', 5, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 1, 0, 0, True, None, True],
['one_failed', 2, 1, 0, 0, 3, True, None, False],
['one_failed', 2, 3, 0, 0, 5, True, State.SKIPPED, False],
['one_failed', 5, 0, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 1, 0, 0, 0, True, None, True],
['one_failed', 2, 1, 0, 0, 3, 0, True, None, False],
['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False],
#
# Tests for done
#
['all_done', 5, 0, 0, 0, 5, True, None, True],
['all_done', 2, 0, 0, 0, 2, True, None, False],
['all_done', 2, 0, 1, 0, 3, True, None, False],
['all_done', 2, 1, 0, 0, 3, True, None, False],
['all_done', 5, 0, 0, 0, 5, 0, True, None, True],
['all_done', 2, 0, 0, 0, 2, 0, True, None, False],
['all_done', 2, 0, 1, 0, 3, 0, True, None, False],
['all_done', 2, 1, 0, 0, 3, 0, True, None, False],
],
)
def test_check_task_dependencies(
Expand All @@ -1122,6 +1122,7 @@ def test_check_task_dependencies(
successes: int,
skipped: int,
failed: int,
removed: int,
upstream_failed: int,
done: int,
flag_upstream_failed: bool,
Expand All @@ -1144,6 +1145,121 @@ def test_check_task_dependencies(
successes=successes,
skipped=skipped,
failed=failed,
removed=removed,
upstream_failed=upstream_failed,
done=done,
dep_context=DepContext(),
flag_upstream_failed=flag_upstream_failed,
)
completed = all(dep.passed for dep in dep_results)

assert completed == expect_completed
assert ti.state == expect_state

# Parameterized tests to check for the correct firing
# of the trigger_rule under various circumstances of mapped task
# Numeric fields are in order:
# successes, skipped, failed, upstream_failed, done,removed
@pytest.mark.parametrize(
"trigger_rule,successes,skipped,failed,upstream_failed,done,removed,"
"flag_upstream_failed,expect_state,expect_completed",
[
#
# Tests for all_success
#
['all_success', 5, 0, 0, 0, 0, 0, True, None, True],
['all_success', 2, 0, 0, 0, 0, 0, True, None, False],
['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, False],
['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False],
['all_success', 3, 0, 0, 0, 0, 2, True, State.REMOVED, True], # ti.map_index >=successes
#
# Tests for one_success
#
['one_success', 5, 0, 0, 0, 5, 0, True, None, True],
['one_success', 2, 0, 0, 0, 2, 0, True, None, True],
['one_success', 2, 0, 1, 0, 3, 0, True, None, True],
['one_success', 2, 1, 0, 0, 3, 0, True, None, True],
['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False],
['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, False],
#
# Tests for all_failed
#
['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False],
['all_failed', 0, 0, 5, 0, 5, 0, True, None, True],
['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False],
['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 4, 1, True, State.SKIPPED, False], # One removed
#
# Tests for one_failed
#
['one_failed', 5, 0, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 1, 0, 0, 0, True, None, True],
['one_failed', 2, 1, 0, 0, 3, 0, True, None, False],
['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False],
['one_failed', 2, 2, 0, 0, 5, 1, True, State.SKIPPED, False], # One removed
#
# Tests for done
#
['all_done', 5, 0, 0, 0, 5, 0, True, None, True],
['all_done', 2, 0, 0, 0, 2, 0, True, None, False],
['all_done', 2, 0, 1, 0, 3, 0, True, None, False],
['all_done', 2, 1, 0, 0, 3, 0, True, None, False],
],
)
def test_check_task_dependencies_for_mapped(
self,
trigger_rule: str,
successes: int,
skipped: int,
failed: int,
removed: int,
upstream_failed: int,
done: int,
flag_upstream_failed: bool,
expect_state: State,
expect_completed: bool,
dag_maker,
session,
):
from airflow.decorators import task

@task
def do_something(i):
return 1

@task(trigger_rule=trigger_rule)
def do_something_else(i):
return 1

with dag_maker(dag_id='test_dag'):
nums = do_something.expand(i=[i + 1 for i in range(5)])
do_something_else.expand(i=nums)

dr = dag_maker.create_dagrun()

ti = dr.get_task_instance('do_something_else', session=session)
ti.map_index = 0
for map_index in range(1, 5):
ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index)
ti.dag_run = dr
session.add(ti)
session.flush()
downstream = ti.task
ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session)
ti.task = downstream
dep_results = TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=successes,
skipped=skipped,
failed=failed,
removed=removed,
upstream_failed=upstream_failed,
done=done,
dep_context=DepContext(),
Expand Down
Loading