From fb2872a35890d21be1db3d934038edc835be9f4f Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Tue, 30 May 2023 11:22:10 +0200 Subject: [PATCH 1/3] add unit tests for default_args overriding in task group Signed-off-by: Hussein Awala --- tests/decorators/test_task_group.py | 32 +++++++++++++++++++++++++++++ tests/utils/test_task_group.py | 25 ++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py index 38b54cf1c52f0..7f59fe181707a 100644 --- a/tests/decorators/test_task_group.py +++ b/tests/decorators/test_task_group.py @@ -17,11 +17,14 @@ # under the License. from __future__ import annotations +from datetime import timedelta + import pendulum import pytest from airflow.decorators import dag, task_group from airflow.models.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput, MappedArgument +from airflow.operators.empty import EmptyOperator from airflow.utils.task_group import MappedTaskGroup @@ -186,3 +189,32 @@ def tg(a, b): assert tg._expand_input == ListOfDictsExpandInput([{"b": "x"}, {"b": None}]) assert saved == {"a": 1, "b": MappedArgument(input=tg._expand_input, key="b")} + + +def test_override_dag_default_args(): + @dag( + dag_id="test_dag", + start_date=pendulum.parse("20200101"), + default_args={ + "retries": 1, + "owner": "x", + }, + ) + def pipeline(): + @task_group( + group_id="task_group", + default_args={ + "owner": "y", + "execution_timeout": timedelta(seconds=10), + }, + ) + def tg(): + EmptyOperator(task_id="task") + + tg() + + test_dag = pipeline() + test_task = test_dag.task_group_dict["task_group"].children["task_group.task"] + assert test_task.retries == 1 + assert test_task.owner == "y" + assert test_task.execution_timeout == timedelta(seconds=10) diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index ba6f174773f0a..3830dfe2b8d40 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +from datetime import timedelta + import pendulum import pytest @@ -1301,3 +1303,26 @@ def test_iter_tasks(): "section_2.task3", "section_2.bash_task", ] + + +def test_override_dag_default_args(): + with DAG( + dag_id="test_dag", + start_date=pendulum.parse("20200101"), + default_args={ + "retries": 1, + "owner": "x", + }, + ): + with TaskGroup( + group_id="task_group", + default_args={ + "owner": "y", + "execution_timeout": timedelta(seconds=10), + }, + ): + task = EmptyOperator(task_id="task") + + assert task.retries == 1 + assert task.owner == "y" + assert task.execution_timeout == timedelta(seconds=10) From c990fbe8348fe513998d47946c33d3b96a638790 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Tue, 30 May 2023 11:34:10 +0200 Subject: [PATCH 2/3] fix overriding default args in nested task groups Signed-off-by: Hussein Awala --- airflow/utils/task_group.py | 5 +++ tests/decorators/test_task_group.py | 42 ++++++++++++++++++++++++ tests/utils/test_task_group.py | 50 +++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+) diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 85ec1eb0d3b09..99942b10815e9 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -140,6 +140,7 @@ def __init__( if parent_group: parent_group.add(self) + self._update_default_args(parent_group) self.used_group_ids.add(self.group_id) if self.group_id: @@ -176,6 +177,10 @@ def _check_for_group_id_collisions(self, add_suffix_on_collision: bool): else: self._group_id = f"{base}__{suffixes[-1] + 1}" + def _update_default_args(self, parent_group: TaskGroup): + if parent_group.default_args: + self.default_args.update(parent_group.default_args) + @classmethod def create_root(cls, dag: DAG) -> TaskGroup: """Create a root TaskGroup with no group_id or parent.""" diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py index 7f59fe181707a..3462c3a1d83a5 100644 --- a/tests/decorators/test_task_group.py +++ b/tests/decorators/test_task_group.py @@ -218,3 +218,45 @@ def tg(): assert test_task.retries == 1 assert test_task.owner == "y" assert test_task.execution_timeout == timedelta(seconds=10) + + +def test_override_dag_default_args_nested_tg(): + @dag( + dag_id="test_dag", + start_date=pendulum.parse("20200101"), + default_args={ + "retries": 1, + "owner": "x", + }, + ) + def pipeline(): + @task_group( + group_id="task_group", + default_args={ + "owner": "y", + "execution_timeout": timedelta(seconds=10), + }, + ) + def tg(): + @task_group(group_id="nested_task_group") + def nested_tg(): + @task_group(group_id="another_task_group") + def another_tg(): + EmptyOperator(task_id="task") + + another_tg() + + nested_tg() + + tg() + + test_dag = pipeline() + test_task = ( + test_dag.task_group_dict["task_group"] + .children["task_group.nested_task_group"] + .children["task_group.nested_task_group.another_task_group"] + .children["task_group.nested_task_group.another_task_group.task"] + ) + assert test_task.retries == 1 + assert test_task.owner == "y" + assert test_task.execution_timeout == timedelta(seconds=10) diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 3830dfe2b8d40..c9927eb3ba257 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -1326,3 +1326,53 @@ def test_override_dag_default_args(): assert task.retries == 1 assert task.owner == "y" assert task.execution_timeout == timedelta(seconds=10) + + +def test_override_dag_default_args_in_nested_tg(): + with DAG( + dag_id="test_dag", + start_date=pendulum.parse("20200101"), + default_args={ + "retries": 1, + "owner": "x", + }, + ): + with TaskGroup( + group_id="task_group", + default_args={ + "owner": "y", + "execution_timeout": timedelta(seconds=10), + }, + ): + with TaskGroup(group_id="nested_task_group"): + task = EmptyOperator(task_id="task") + + assert task.retries == 1 + assert task.owner == "y" + assert task.execution_timeout == timedelta(seconds=10) + + +def test_override_dag_default_args_in_multi_level_nested_tg(): + with DAG( + dag_id="test_dag", + start_date=pendulum.parse("20200101"), + default_args={ + "retries": 1, + "owner": "x", + }, + ): + with TaskGroup( + group_id="task_group", + default_args={ + "owner": "y", + "execution_timeout": timedelta(seconds=10), + }, + ): + with TaskGroup(group_id="first_nested_task_group"): + with TaskGroup(group_id="second_nested_task_group"): + with TaskGroup(group_id="third_nested_task_group"): + task = EmptyOperator(task_id="task") + + assert task.retries == 1 + assert task.owner == "y" + assert task.execution_timeout == timedelta(seconds=10) From b7ddc2765964679e7c5e619782cf8333f1c9319b Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Tue, 30 May 2023 14:27:56 +0200 Subject: [PATCH 3/3] Update airflow/utils/task_group.py Co-authored-by: Ash Berlin-Taylor --- airflow/utils/task_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 99942b10815e9..4b8392be7b3f0 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -179,7 +179,7 @@ def _check_for_group_id_collisions(self, add_suffix_on_collision: bool): def _update_default_args(self, parent_group: TaskGroup): if parent_group.default_args: - self.default_args.update(parent_group.default_args) + self.default_args = {**self.default_args, **parent_group.default_args} @classmethod def create_root(cls, dag: DAG) -> TaskGroup: