From 21236e6e6ff8ede79e630c54f9161646c8f13ab1 Mon Sep 17 00:00:00 2001 From: Shahar Epstein <60007259+shahar1@users.noreply.github.com> Date: Fri, 6 Dec 2024 23:36:11 +0200 Subject: [PATCH] Prevent using `trigger_rule=TriggerRule.ALWAYS` in a task-generated mapping within bare tasks --- airflow/decorators/base.py | 21 ++++++++++ .../dynamic-task-mapping.rst | 10 +++-- newsfragments/44751.bugfix.rst | 1 + .../src/airflow/sdk/definitions/taskgroup.py | 4 +- tests/decorators/test_mapped.py | 38 +++++++++++++++++++ tests/decorators/test_task_group.py | 5 ++- 6 files changed, 72 insertions(+), 7 deletions(-) create mode 100644 newsfragments/44751.bugfix.rst diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 9a9498d49c307..9b303e1a703e7 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -398,6 +398,12 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]): super()._validate_arg_names(func, kwargs) def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg: + if self.kwargs.get("trigger_rule") == TriggerRule.ALWAYS and any( + [isinstance(expanded, XComArg) for expanded in map_kwargs.values()] + ): + raise ValueError( + "Task-generated mapping within a task using 'expand' is not allowed with trigger rule 'always'." 
+ ) if not map_kwargs: raise TypeError("no arguments to expand against") self._validate_arg_names("expand", map_kwargs) @@ -411,6 +417,21 @@ def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg: return self._expand(DictOfListsExpandInput(map_kwargs), strict=False) def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: + if ( + self.kwargs.get("trigger_rule") == TriggerRule.ALWAYS + and not isinstance(kwargs, XComArg) + and any( + [ + isinstance(v, XComArg) + for kwarg in kwargs + if not isinstance(kwarg, XComArg) + for v in kwarg.values() + ] + ) + ): + raise ValueError( + "Task-generated mapping within a task using 'expand_kwargs' is not allowed with trigger rule 'always'." + ) if isinstance(kwargs, Sequence): for item in kwargs: if not isinstance(item, (XComArg, Mapping)): diff --git a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst index 426a720781ec8..00fb5b473e78d 100644 --- a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst +++ b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst @@ -84,10 +84,6 @@ The grid view also provides visibility into your mapped tasks in the details pan Although we show a "reduce" task here (``sum_it``) you don't have to have one, the mapped tasks will still be executed even if they have no downstream tasks. -.. warning:: ``TriggerRule.ALWAYS`` cannot be utilized in expanded tasks - - Assigning ``trigger_rule=TriggerRule.ALWAYS`` in expanded tasks is forbidden, as expanded parameters will be undefined with the task's immediate execution. - This is enforced at the time of the DAG parsing, and will raise an error if you try to use it. 
Task-generated Mapping ---------------------- @@ -113,6 +109,12 @@ The above examples we've shown could all be achieved with a ``for`` loop in the The ``make_list`` task runs as a normal task and must return a list or dict (see `What data types can be expanded?`_), and then the ``consumer`` task will be called four times, once with each value in the return of ``make_list``. +.. warning:: Task-generated mapping cannot be utilized with ``TriggerRule.ALWAYS`` + + Assigning ``trigger_rule=TriggerRule.ALWAYS`` in task-generated mapping is not allowed, as the expanded parameters are still undefined when the task is executed immediately. + This is enforced at the time of DAG parsing, for both tasks and mapped task groups, and will raise an error if you try to use it. + In the example above, setting ``trigger_rule=TriggerRule.ALWAYS`` in the ``consumer`` task will raise an error since ``make_list`` is a task-generated mapping. + Repeated mapping ---------------- diff --git a/newsfragments/44751.bugfix.rst b/newsfragments/44751.bugfix.rst new file mode 100644 index 0000000000000..1ca32178be1c5 --- /dev/null +++ b/newsfragments/44751.bugfix.rst @@ -0,0 +1 @@ +``TriggerRule.ALWAYS`` cannot be utilized within a task-generated mapping, either in bare tasks (fixed in this PR) or mapped task groups (fixed in PR #44368). The issue with doing so is that the task is immediately executed without waiting for the upstream's mapping results, which certainly leads to failure of the task. This fix avoids it by raising an exception when it is detected during DAG parsing. 
diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py index fd02a4c94e714..52b30ba31f8af 100644 --- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py @@ -597,7 +597,9 @@ def __iter__(self): for child in self.children.values(): if isinstance(child, AbstractOperator) and child.trigger_rule == TriggerRule.ALWAYS: - raise ValueError("Tasks in a mapped task group cannot have trigger_rule set to 'ALWAYS'") + raise ValueError( + "Task-generated mapping within a mapped task group is not allowed with trigger rule 'always'" + ) yield from self._iter_child(child) def iter_mapped_dependencies(self) -> Iterator[DAGNode]: diff --git a/tests/decorators/test_mapped.py b/tests/decorators/test_mapped.py index 9bd59f03be7e1..4b3faf119fd81 100644 --- a/tests/decorators/test_mapped.py +++ b/tests/decorators/test_mapped.py @@ -61,3 +61,41 @@ def f(x: int, y: int) -> int: xcoms.add(ti.xcom_pull(session=session, task_ids=ti.task_id, map_indexes=ti.map_index)) assert xcoms == {11, 12, 13} + + +@pytest.mark.db_test +def test_fail_task_generated_mapping_with_trigger_rule_always__exapnd(dag_maker, session): + with DAG(dag_id="d", schedule=None, start_date=DEFAULT_DATE): + + @task + def get_input(): + return ["world", "moon"] + + @task(trigger_rule="always") + def hello(input): + print(f"Hello, {input}") + + with pytest.raises( + ValueError, + match="Task-generated mapping within a task using 'expand' is not allowed with trigger rule 'always'", + ): + hello.expand(input=get_input()) + + +@pytest.mark.db_test +def test_fail_task_generated_mapping_with_trigger_rule_always__exapnd_kwargs(dag_maker, session): + with DAG(dag_id="d", schedule=None, start_date=DEFAULT_DATE): + + @task + def get_input(): + return ["world", "moon"] + + @task(trigger_rule="always") + def hello(input, input2): + print(f"Hello, {input}, {input2}") + + with pytest.raises( + ValueError, + 
match="Task-generated mapping within a task using 'expand_kwargs' is not allowed with trigger rule 'always'", + ): + hello.expand_kwargs([{"input": get_input(), "input2": get_input()}]) diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py index 2dab23ca38fc7..ce1b518a8ff59 100644 --- a/tests/decorators/test_task_group.py +++ b/tests/decorators/test_task_group.py @@ -135,7 +135,7 @@ def tg(): @pytest.mark.db_test -def test_expand_fail_trigger_rule_always(dag_maker, session): +def test_fail_task_generated_mapping_with_trigger_rule_always(dag_maker, session): @dag(schedule=None, start_date=pendulum.datetime(2022, 1, 1)) def pipeline(): @task @@ -151,7 +151,8 @@ def tg(param): t1(param) with pytest.raises( - ValueError, match="Tasks in a mapped task group cannot have trigger_rule set to 'ALWAYS'" + ValueError, + match="Task-generated mapping within a mapped task group is not allowed with trigger rule 'always'", ): tg.expand(param=get_param())