Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2215,29 +2215,40 @@ def _deepcopy_task(t) -> Operator:

def filter_task_group(group, parent_group):
    """Return a copy of ``group`` restricted to the tasks present in ``dag``.

    Relies on ``dag`` and ``memo`` from the enclosing scope (not visible in
    this block): ``dag.task_dict`` holds the already-copied tasks that made
    the cut, and ``memo`` is the ``copy.deepcopy`` memo dict shared with the
    rest of the copy.

    :param group: the original TaskGroup to copy and filter.
    :param parent_group: the already-copied parent to attach the copy to
        (a weak proxy when called recursively), or ``None`` for the root.
    :return: the filtered copy; child groups that end up with no children
        are left out.
    """
    # We want to deepcopy _most but not all_ attributes of the task group,
    # so we take a shallow copy and then deep-copy its attributes one by
    # one, letting ``memo`` substitute the values we must not copy.
    copied = copy.copy(group)

    # The children are rebuilt below, so map the original mapping to a
    # fresh empty dict instead of deep-copying every child.
    memo[id(group.children)] = {}
    if parent_group:
        # Redirect the original parent to the copied parent.
        # NOTE(review): assumes ``group.parent_group`` returns the very
        # object stored in the instance ``__dict__`` -- confirm.
        memo[id(group.parent_group)] = parent_group
    # Only values of existing keys are reassigned, so the dict does not
    # change size while it is being iterated.
    for attr, value in copied.__dict__.items():
        if id(value) in memo:
            value = memo[id(value)]
        else:
            value = copy.deepcopy(value, memo)
        copied.__dict__[attr] = value

    # Tasks and child groups point back at their group through a weak
    # proxy; build it once and share it.
    proxy = weakref.proxy(copied)

    for child in group.children.values():
        if isinstance(child, AbstractOperator):
            if child.task_id in dag.task_dict:
                task = copied.children[child.task_id] = dag.task_dict[child.task_id]
                task.task_group = proxy
            else:
                # The task did not make the cut: forget its ID as well.
                copied.used_group_ids.discard(child.task_id)
        else:
            filtered_child = filter_task_group(child, proxy)

            # Only include this child TaskGroup if it is non-empty.
            if filtered_child.children:
                copied.children[child.group_id] = filtered_child

    return copied

dag._task_group = filter_task_group(self._task_group, None)
dag._task_group = filter_task_group(self.task_group, None)

# Removing upstream/downstream references to tasks and TaskGroups that did not make
# the cut.
Expand Down
35 changes: 31 additions & 4 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pickle
import re
import sys
import weakref
from contextlib import redirect_stdout
from datetime import timedelta
from pathlib import Path
Expand Down Expand Up @@ -1381,19 +1382,45 @@ def test_duplicate_task_ids_for_same_task_is_allowed(self):
assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3}
assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3}

def test_partial_subset_updates_all_references_while_deepcopy(self):
    """Tasks in a partial subset must reference the subset, not the source DAG."""
    with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
        op1 = EmptyOperator(task_id="t1")
        op2 = EmptyOperator(task_id="t2")
        op3 = EmptyOperator(task_id="t3")
        op1 >> op2
        op2 >> op3

    partial = dag.partial_subset("t2", include_upstream=True, include_downstream=False)

    # The task downstream of the copied "t1" must belong to the copied DAG.
    assert partial.task_dict["t1"].downstream_list[0].dag is partial

    # Copied DAG should not include unused task IDs in used_group_ids
    assert "t3" not in partial.task_group.used_group_ids

def test_partial_subset_taskgroup_join_ids(self):
    """A partial subset keeps group wiring without sharing state with the source DAG."""
    with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
        start = EmptyOperator(task_id="start")
        with TaskGroup(group_id="outer", prefix_group_id=False) as outer_group:
            with TaskGroup(group_id="tg1", prefix_group_id=False) as tg1:
                EmptyOperator(task_id="t1")
            with TaskGroup(group_id="tg2", prefix_group_id=False) as tg2:
                EmptyOperator(task_id="t2")

            start >> tg1 >> tg2

    # Sanity-check the source DAG before taking the subset.
    source_task = dag.get_task("t2")
    source_group = source_task.task_group
    assert source_group.upstream_group_ids == {"tg1"}
    assert isinstance(source_group.parent_group, weakref.ProxyType)
    assert source_group.parent_group == outer_group

    subset = dag.partial_subset(["t2"], include_upstream=True, include_downstream=False)
    subset_group = subset.get_task("t2").task_group
    assert subset_group.upstream_group_ids == {"tg1"}
    assert isinstance(subset_group.parent_group, weakref.ProxyType)
    assert subset_group.parent_group

    # The copy must own its set of upstream group IDs rather than share the original's.
    assert source_task.task_group.upstream_group_ids is not subset_group.upstream_group_ids

def test_schedule_dag_no_previous_runs(self):
"""
Expand Down