diff --git a/exir/passes/replace_view_copy_with_view_pass.py b/exir/passes/replace_view_copy_with_view_pass.py index 8d3a2a32126..378b9332119 100644 --- a/exir/passes/replace_view_copy_with_view_pass.py +++ b/exir/passes/replace_view_copy_with_view_pass.py @@ -275,7 +275,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: for node in module.graph.nodes: # Note: We only replace view_copy nodes that are not output, since # the output pointer could be modified at runtime (T187925929) - if _is_view_copy(node) and node.next.op != "output": + if _is_view_copy(node) and all(u.op != "output" for u in node.users): base, _ = node.args node.target = _VIEW_OP @@ -302,7 +302,9 @@ def ensures(self, graph_module: torch.fx.GraphModule) -> None: for node in module.graph.nodes: # Note: We only replace view_copy nodes that are not output, since # the output pointer could be modified at runtime (T187925929) - assert not (_is_view_copy(node) and node.next.op != "output") + assert not ( + _is_view_copy(node) and all(u.op != "output" for u in node.users) + ) if node.op == "call_function" and node.target == _VIEW_OP: assert isinstance(node.meta["spec"], _ViewSpec) @@ -317,6 +319,6 @@ def requires(self, graph_module: torch.fx.GraphModule) -> None: for node in module.graph.nodes: # Note: We only replace view_copy nodes that are not output, since # the output pointer could be modified at runtime (T187925929) - if _is_view_copy(node) and node.next.op != "output": + if _is_view_copy(node) and all(u.op != "output" for u in node.users): base, size = node.args assert not _is_view_copy(base) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 0377f70a150..f65ccff13b0 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1602,7 +1602,9 @@ def __init__(self): def forward(self, x): o1 = torch.ops.aten.view_copy.default(x, [1]) o2 = torch.ops.aten.view_copy.default(self.parameter, [1]) - return o1, o2 + # view_copys at the end of a function are not replaced, so add + # a computation before the end of the graph. + return torch.ops.aten.add.Tensor(o1, o2) ep = torch.export.export( TestViewCopies(), @@ -1631,10 +1633,9 @@ def forward(self, x): gm = gm_res.graph_module # Check after transformation - # Note: one view copy is not replaced, because it's the output of the graph FileCheck().check_count( - "torch.ops.aten.view_copy.default", 1, exactly=True + "torch.ops.aten.view_copy.default", 0, exactly=True ).run(gm.code) - FileCheck().check_count("executorch_exir_memory_view", 1, exactly=True).run( + FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run( gm.code ) diff --git a/exir/tests/test_remove_view_copy.py b/exir/tests/test_remove_view_copy.py index b3ad1f7d5a7..f64a1f19981 100644 --- a/exir/tests/test_remove_view_copy.py +++ b/exir/tests/test_remove_view_copy.py @@ -32,7 +32,8 @@ def forward(self, x): ) # removed, lifetime of mul.Tensor will be extended v4 = torch.ops.aten.mul.Tensor(v3, self.parameter2) v5 = v4.view(6, 5) # not removed, output of the graph - return v5 + v6 = v4.view(2, 15) # not removed, output of the graph + return v5, v6 def get_example_inputs(self): return (torch.rand(5, 6),) @@ -87,10 +88,15 @@ def test_output_matches(self) -> None: ), ) - out_remove = etpm_remove.exported_program().module()(*example_inputs) - out_no_remove = etpm_no_remove.exported_program().module()(*example_inputs) + out_remove_v5, out_remove_v6 = etpm_remove.exported_program().module()( + *example_inputs + ) + out_no_remove_v5, out_no_remove_v6 = etpm_no_remove.exported_program().module()( + *example_inputs + ) - self.assertTrue(torch.allclose(out_remove, out_no_remove)) + self.assertTrue(torch.allclose(out_remove_v5, out_no_remove_v5)) + self.assertTrue(torch.allclose(out_remove_v6, out_no_remove_v6)) def test_spec(self) -> None: model = TestModel1() @@ -196,7 +202,7 @@ def test_spec(self) -> None: self.assertEqual(plan.operators[2].name, "aten::view_copy") instructions = plan.chains[0].instructions - self.assertEqual(len(instructions), 6) + self.assertEqual(len(instructions), 7) self.assertEqual( instructions[0].instr_args.op_index, 0 # pyre-ignore @@ -216,3 +222,6 @@ def test_spec(self) -> None: self.assertEqual( instructions[5].instr_args.op_index, 2 # pyre-ignore ) # aten:view_copy @ idx11 + self.assertEqual( + instructions[6].instr_args.op_index, 2 # pyre-ignore + ) # aten:view_copy @ idx11