From fae0b49dac189534241086d7400255e4392daf9d Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 15 Mar 2023 16:21:21 +0000 Subject: [PATCH 1/4] Ensure that `dag.partial_subset` doesn't mutate task group properties We had a few properties that we failed to copy that we should have. To fix that in such a way that we don't miss things in the future, I've converted it to deepcopy everything by default, excluding `children` and `parent_group`. This also made me notice that we were not correctly setting `parent_group` after partial_subset anymore -- it clearly hasn't mattered, but we were setting a now-unused `_parent_group` attribute. --- airflow/models/dag.py | 16 ++++++++-------- tests/models/test_dag.py | 35 +++++++++++++++++++++++++++++++---- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 61f8b130fa488..1fe0436319ec2 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2215,21 +2215,21 @@ def _deepcopy_task(t) -> Operator: def filter_task_group(group, parent_group): """Exclude tasks not included in the subdag from the given TaskGroup.""" - copied = copy.copy(group) - copied.used_group_ids = set(copied.used_group_ids) - copied._parent_group = parent_group - - copied.children = {} + memo = {id(group.children): {}} + if parent_group: + memo[id(group.parent_group)] = parent_group + copied = copy.deepcopy(group, memo) + proxy = weakref.proxy(copied) for child in group.children.values(): if isinstance(child, AbstractOperator): if child.task_id in dag.task_dict: task = copied.children[child.task_id] = dag.task_dict[child.task_id] - task.task_group = weakref.proxy(copied) + task.task_group = proxy else: copied.used_group_ids.discard(child.task_id) else: - filtered_child = filter_task_group(child, copied) + filtered_child = filter_task_group(child, proxy) # Only include this child TaskGroup if it is non-empty. 
if filtered_child.children: @@ -2237,7 +2237,7 @@ def filter_task_group(group, parent_group): return copied - dag._task_group = filter_task_group(self._task_group, None) + dag._task_group = filter_task_group(self.task_group, None) # Removing upstream/downstream references to tasks and TaskGroups that did not make # the cut. diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index d5a8de4864346..46df64d043786 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -24,6 +24,7 @@ import pickle import re import sys +import weakref from contextlib import redirect_stdout from datetime import timedelta from pathlib import Path @@ -1381,7 +1382,7 @@ def test_duplicate_task_ids_for_same_task_is_allowed(self): assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3} assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3} - def test_sub_dag_updates_all_references_while_deepcopy(self): + def test_partial_subset_updates_all_references_while_deepcopy(self): with DAG("test_dag", start_date=DEFAULT_DATE) as dag: op1 = EmptyOperator(task_id="t1") op2 = EmptyOperator(task_id="t2") @@ -1389,11 +1390,37 @@ def test_sub_dag_updates_all_references_while_deepcopy(self): op1 >> op2 op2 >> op3 - sub_dag = dag.partial_subset("t2", include_upstream=True, include_downstream=False) - assert id(sub_dag.task_dict["t1"].downstream_list[0].dag) == id(sub_dag) + partial = dag.partial_subset("t2", include_upstream=True, include_downstream=False) + assert id(partial.task_dict["t1"].downstream_list[0].dag) == id(partial) # Copied DAG should not include unused task IDs in used_group_ids - assert "t3" not in sub_dag._task_group.used_group_ids + assert "t3" not in partial.task_group.used_group_ids + + def test_partial_subset_taskgroup_join_ids(self): + with DAG("test_dag", start_date=DEFAULT_DATE) as dag: + start = EmptyOperator(task_id="start") + with TaskGroup(group_id="outer", prefix_group_id=False) as outer_group: + with TaskGroup(group_id="tg1", 
prefix_group_id=False) as tg1: + EmptyOperator(task_id="t1") + with TaskGroup(group_id="tg2", prefix_group_id=False) as tg2: + EmptyOperator(task_id="t2") + + start >> tg1 >> tg2 + + # Pre-condition checks + task = dag.get_task("t2") + assert task.task_group.upstream_group_ids == {"tg1"} + assert isinstance(task.task_group.parent_group, weakref.ProxyType) + assert task.task_group.parent_group == outer_group + + partial = dag.partial_subset(["t2"], include_upstream=True, include_downstream=False) + copied_task = partial.get_task("t2") + assert copied_task.task_group.upstream_group_ids == {"tg1"} + assert isinstance(copied_task.task_group.parent_group, weakref.ProxyType) + assert copied_task.task_group.parent_group + + # Make sure we don't affect the original! + assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids def test_schedule_dag_no_previous_runs(self): """ From 0791566646d1a6e4d2b9bf59092c8d9bae43ebcc Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 15 Mar 2023 18:22:34 +0000 Subject: [PATCH 2/4] fixup! Ensure that `dag.partial_subset` doesn't mutate task group properties --- airflow/models/dag.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 1fe0436319ec2..b9e886ca57912 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2215,10 +2215,21 @@ def _deepcopy_task(t) -> Operator: def filter_task_group(group, parent_group): """Exclude tasks not included in the subdag from the given TaskGroup.""" - memo = {id(group.children): {}} + # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy + # and then manually deep copy the instances. 
(memo argyment to deepcopy only works for instances + # of classes, not "native" properties of an instance, ) + copied = copy.copy(group) + + memo[id(group.children)] = {} if parent_group: memo[id(group.parent_group)] = parent_group - copied = copy.deepcopy(group, memo) + for attr, value in copied.__dict__.items(): + if id(value) in memo: + value = memo[id(value)] + else: + value = copy.deepcopy(value, memo) + copied.__dict__[attr] = value + proxy = weakref.proxy(copied) for child in group.children.values(): From 05e8700f100500a854c009713c45c4f1424ecd73 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 15 Mar 2023 18:23:59 +0000 Subject: [PATCH 3/4] Update airflow/models/dag.py --- airflow/models/dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index b9e886ca57912..0564cc15b3b57 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2216,7 +2216,7 @@ def _deepcopy_task(t) -> Operator: def filter_task_group(group, parent_group): """Exclude tasks not included in the subdag from the given TaskGroup.""" # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy - # and then manually deep copy the instances. (memo argyment to deepcopy only works for instances + # and then manually deep copy the instances. 
(memo argument to deepcopy only works for instances # of classes, not "native" properties of an instance, ) copied = copy.copy(group) From 79b9b4ebc4aa341842d8d4d313f0d9246dee9e93 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 15 Mar 2023 20:04:42 +0000 Subject: [PATCH 4/4] Update airflow/models/dag.py Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- airflow/models/dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 0564cc15b3b57..1e653ad19e94e 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2217,7 +2217,7 @@ def filter_task_group(group, parent_group): """Exclude tasks not included in the subdag from the given TaskGroup.""" # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy # and then manually deep copy the instances. (memo argument to deepcopy only works for instances - # of classes, not "native" properties of an instance, ) + # of classes, not "native" properties of an instance) copied = copy.copy(group) memo[id(group.children)] = {}