Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ class derived from this one results in the creation of a task object,
|experimental|
:param trigger_rule: defines the rule by which dependencies are applied
for the task to get triggered. Options are:
``{ all_success | all_failed | all_done | all_skipped | one_success |
``{ all_success | all_failed | all_done | all_skipped | one_success | one_done |
one_failed | none_failed | none_failed_min_one_success | none_skipped | always}``
default is ``all_success``. Options can be set as string or
using the constants defined in the static class
Expand Down
13 changes: 13 additions & 0 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def _evaluate_trigger_rule(
elif trigger_rule == TR.ONE_FAILED:
if upstream_done and not (failed or upstream_failed):
changed = ti.set_state(State.SKIPPED, session)
elif trigger_rule == TR.ONE_DONE:
if upstream_done and not (failed or successes):
changed = ti.set_state(State.SKIPPED, session)
elif trigger_rule == TR.NONE_FAILED:
if upstream_failed or failed:
changed = ti.set_state(State.UPSTREAM_FAILED, session)
Expand Down Expand Up @@ -209,6 +212,16 @@ def _evaluate_trigger_rule(
f"upstream_task_ids={task.upstream_task_ids}"
)
)
elif trigger_rule == TR.ONE_DONE:
if successes + failed <= 0:
yield self._failing_status(
reason=(
f"Task's trigger rule '{trigger_rule}'"
"requires at least one upstream task failure or success"
f"but none were failed or success. upstream_tasks_state={upstream_tasks_state}, "
f"upstream_task_ids={task.upstream_task_ids}"
)
)
elif trigger_rule == TR.ALL_SUCCESS:
num_failures = upstream - successes
if num_failures > 0:
Expand Down
1 change: 1 addition & 0 deletions airflow/utils/trigger_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class TriggerRule(str, Enum):
ALL_DONE = 'all_done'
ONE_SUCCESS = 'one_success'
ONE_FAILED = 'one_failed'
ONE_DONE = 'one_done'
NONE_FAILED = 'none_failed'
NONE_FAILED_OR_SKIPPED = 'none_failed_or_skipped'
NONE_SKIPPED = 'none_skipped'
Expand Down
1 change: 1 addition & 0 deletions docs/apache-airflow/concepts/dags.rst
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ However, this is just the default behaviour, and you can control it using the ``
* ``all_skipped``: All upstream tasks are in a ``skipped`` state
* ``one_failed``: At least one upstream task has failed (does not wait for all upstream tasks to be done)
* ``one_success``: At least one upstream task has succeeded (does not wait for all upstream tasks to be done)
* ``one_done``: At least one upstream task succeeded or failed
* ``none_failed``: All upstream tasks have not ``failed`` or ``upstream_failed`` - that is, all upstream tasks have succeeded or been skipped
* ``none_failed_min_one_success``: All upstream tasks have not ``failed`` or ``upstream_failed``, and at least one upstream task has succeeded.
* ``none_skipped``: No upstream task is in a ``skipped`` state - that is, all upstream tasks are in a ``success``, ``failed``, or ``upstream_failed`` state
Expand Down
77 changes: 77 additions & 0 deletions tests/ti_deps/deps/test_trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,83 @@ def test_one_failure_tr_success(self, get_task_instance):
)
assert len(dep_statuses) == 0

def test_one_done_tr_success(self, get_task_instance):
"""
One-done trigger rule success
"""
ti = get_task_instance(TriggerRule.ONE_DONE)
dep_statuses = tuple(
TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=2,
skipped=0,
failed=0,
upstream_failed=0,
done=2,
flag_upstream_failed=False,
dep_context=DepContext(),
session="Fake Session",
)
)
assert len(dep_statuses) == 0

dep_statuses = tuple(
TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=0,
skipped=0,
failed=2,
upstream_failed=0,
done=2,
flag_upstream_failed=False,
dep_context=DepContext(),
session="Fake Session",
)
)
assert len(dep_statuses) == 0

def test_one_done_tr_skip(self, get_task_instance):
"""
One-done trigger rule skip
"""
ti = get_task_instance(TriggerRule.ONE_DONE)
dep_statuses = tuple(
TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=0,
skipped=2,
failed=0,
upstream_failed=0,
done=2,
flag_upstream_failed=False,
dep_context=DepContext(),
session="Fake Session",
)
)
assert len(dep_statuses) == 1
assert not dep_statuses[0].passed

def test_one_done_tr_upstream_failed(self, get_task_instance):
"""
One-done trigger rule upstream_failed
"""
ti = get_task_instance(TriggerRule.ONE_DONE)
dep_statuses = tuple(
TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=0,
skipped=0,
failed=0,
upstream_failed=2,
done=2,
flag_upstream_failed=False,
dep_context=DepContext(),
session="Fake Session",
)
)
assert len(dep_statuses) == 1
assert not dep_statuses[0].passed

def test_all_success_tr_success(self, get_task_instance):
"""
All-success trigger rule success
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_trigger_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ def test_valid_trigger_rules(self):
assert TriggerRule.is_valid(TriggerRule.ALL_SKIPPED)
assert TriggerRule.is_valid(TriggerRule.ONE_SUCCESS)
assert TriggerRule.is_valid(TriggerRule.ONE_FAILED)
assert TriggerRule.is_valid(TriggerRule.ONE_DONE)
assert TriggerRule.is_valid(TriggerRule.NONE_FAILED)
assert TriggerRule.is_valid(TriggerRule.NONE_FAILED_OR_SKIPPED)
assert TriggerRule.is_valid(TriggerRule.NONE_SKIPPED)
assert TriggerRule.is_valid(TriggerRule.DUMMY)
assert TriggerRule.is_valid(TriggerRule.ALWAYS)
assert TriggerRule.is_valid(TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
assert len(TriggerRule.all_triggers()) == 12
assert len(TriggerRule.all_triggers()) == 13

with pytest.raises(ValueError):
TriggerRule("NOT_EXIST_TRIGGER_RULE")