diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py
index c56d14f44e..68aead7d71 100644
--- a/dvc/repo/reproduce.py
+++ b/dvc/repo/reproduce.py
@@ -33,11 +33,18 @@ def _get_active_graph(G):
     for stage in G:
         if not stage.locked:
             continue
-        for st in nx.dfs_postorder_nodes(G, stage):
-            if st == stage:
-                continue
-            if st in active:
-                active.remove_node(st)
+        active.remove_edges_from(G.out_edges(stage))
+        for edge in G.out_edges(stage):
+            _, to_stage = edge
+            for node in nx.dfs_preorder_nodes(G, to_stage):
+                # NOTE: `in_degree` will return InDegreeView({}) if stage
+                # no longer exists in the `active` DAG.
+                if not active.in_degree(node):
+                    # NOTE: if some edge no longer exists `remove_edges_from`
+                    # will ignore it without error.
+                    active.remove_edges_from(G.out_edges(node))
+                    active.remove_node(node)
+
     return active
 
 
diff --git a/tests/unit/test_repo.py b/tests/unit/repo/test_repo.py
similarity index 100%
rename from tests/unit/test_repo.py
rename to tests/unit/repo/test_repo.py
diff --git a/tests/unit/repo/test_reproduce.py b/tests/unit/repo/test_reproduce.py
new file mode 100644
index 0000000000..aef2a3bd2e
--- /dev/null
+++ b/tests/unit/repo/test_reproduce.py
@@ -0,0 +1,25 @@
+from dvc.repo.reproduce import _get_active_graph
+
+
+def test_get_active_graph(tmp_dir, dvc):
+    pre_foo_stage, = tmp_dir.dvc_gen({"pre-foo": "pre-foo"})
+    foo_stage = dvc.run(deps=["pre-foo"], outs=["foo"], cmd="echo foo > foo")
+    bar_stage = dvc.run(deps=["foo"], outs=["bar"], cmd="echo bar > bar")
+    baz_stage = dvc.run(deps=["foo"], outs=["baz"], cmd="echo baz > baz")
+
+    dvc.lock_stage("bar.dvc")
+
+    graph = dvc.graph
+    active_graph = _get_active_graph(graph)
+    assert active_graph.nodes == graph.nodes
+    assert set(active_graph.edges) == {
+        (foo_stage, pre_foo_stage),
+        (baz_stage, foo_stage),
+    }
+
+    dvc.lock_stage("baz.dvc")
+
+    graph = dvc.graph
+    active_graph = _get_active_graph(graph)
+    assert set(active_graph.nodes) == {bar_stage, baz_stage}
+    assert not active_graph.edges