From 81ab3fb87f642852a07fbabfea8c2633891669dd Mon Sep 17 00:00:00 2001 From: Riley Dulin Date: Tue, 14 May 2024 10:07:17 -0700 Subject: [PATCH] Fix memory.view insertion except for output nodes (#3602) Summary: The previous implementation of ignoring `view_copy` on outputs was incorrect in that it only checked `node.next` instead of all users of the node. `node.next` just selects the next node in topological order, which may or may not be the output if there is more than one output. In the case of more than one output, the next node may not be related at all! Check if any of the users of the node are an output instead. Reviewed By: metascroy, mcremon-meta Differential Revision: D57299853 --- .../replace_view_copy_with_view_pass.py | 8 +++++--- exir/tests/test_passes.py | 9 +++++---- exir/tests/test_remove_view_copy.py | 19 ++++++++++++++----- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/exir/passes/replace_view_copy_with_view_pass.py b/exir/passes/replace_view_copy_with_view_pass.py index 8d3a2a32126..378b9332119 100644 --- a/exir/passes/replace_view_copy_with_view_pass.py +++ b/exir/passes/replace_view_copy_with_view_pass.py @@ -275,7 +275,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: for node in module.graph.nodes: # Note: We only replace view_copy nodes that are not output, since # the output pointer could be modified at runtime (T187925929) - if _is_view_copy(node) and node.next.op != "output": + if _is_view_copy(node) and all(u.op != "output" for u in node.users): base, _ = node.args node.target = _VIEW_OP @@ -302,7 +302,9 @@ def ensures(self, graph_module: torch.fx.GraphModule) -> None: for node in module.graph.nodes: # Note: We only replace view_copy nodes that are not output, since # the output pointer could be modified at runtime (T187925929) - assert not (_is_view_copy(node) and node.next.op != "output") + assert not ( + _is_view_copy(node) and all(u.op != "output" for u in node.users) + ) if node.op == "call_function" and node.target == _VIEW_OP: assert isinstance(node.meta["spec"], _ViewSpec) @@ -317,6 +319,6 @@ def requires(self, graph_module: torch.fx.GraphModule) -> None: for node in module.graph.nodes: # Note: We only replace view_copy nodes that are not output, since # the output pointer could be modified at runtime (T187925929) - if _is_view_copy(node) and node.next.op != "output": + if _is_view_copy(node) and all(u.op != "output" for u in node.users): base, size = node.args assert not _is_view_copy(base) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 0377f70a150..f65ccff13b0 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1602,7 +1602,9 @@ def __init__(self): def forward(self, x): o1 = torch.ops.aten.view_copy.default(x, [1]) o2 = torch.ops.aten.view_copy.default(self.parameter, [1]) - return o1, o2 + # view_copys at the end of a function are not replaced, so add + # a computation before the end of the graph. + return torch.ops.aten.add.Tensor(o1, o2) ep = torch.export.export( TestViewCopies(), @@ -1631,10 +1633,9 @@ def forward(self, x): gm = gm_res.graph_module # Check after transformation - # Note: one view copy is not replaced, because it's the output of the graph FileCheck().check_count( - "torch.ops.aten.view_copy.default", 1, exactly=True + "torch.ops.aten.view_copy.default", 0, exactly=True ).run(gm.code) - FileCheck().check_count("executorch_exir_memory_view", 1, exactly=True).run( + FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run( gm.code ) diff --git a/exir/tests/test_remove_view_copy.py b/exir/tests/test_remove_view_copy.py index b3ad1f7d5a7..f64a1f19981 100644 --- a/exir/tests/test_remove_view_copy.py +++ b/exir/tests/test_remove_view_copy.py @@ -32,7 +32,8 @@ def forward(self, x): ) # removed, lifetime of mul.Tensor will be extended v4 = torch.ops.aten.mul.Tensor(v3, self.parameter2) v5 = v4.view(6, 5) # not removed, output of the graph - return v5 + v6 = v4.view(2, 15) # not removed, output of the graph + return v5, v6 def get_example_inputs(self): return (torch.rand(5, 6),) @@ -87,10 +88,15 @@ def test_output_matches(self) -> None: ), ) - out_remove = etpm_remove.exported_program().module()(*example_inputs) - out_no_remove = etpm_no_remove.exported_program().module()(*example_inputs) + out_remove_v5, out_remove_v6 = etpm_remove.exported_program().module()( + *example_inputs + ) + out_no_remove_v5, out_no_remove_v6 = etpm_no_remove.exported_program().module()( + *example_inputs + ) - self.assertTrue(torch.allclose(out_remove, out_no_remove)) + self.assertTrue(torch.allclose(out_remove_v5, out_no_remove_v5)) + self.assertTrue(torch.allclose(out_remove_v6, out_no_remove_v6)) def test_spec(self) -> None: model = TestModel1() @@ -196,7 +202,7 @@ def test_spec(self) -> None: self.assertEqual(plan.operators[2].name, "aten::view_copy") instructions = plan.chains[0].instructions - self.assertEqual(len(instructions), 6) + self.assertEqual(len(instructions), 7) self.assertEqual( instructions[0].instr_args.op_index, 0 # pyre-ignore @@ -216,3 +222,6 @@ def test_spec(self) -> None: self.assertEqual( instructions[5].instr_args.op_index, 2 # pyre-ignore ) # aten:view_copy @ idx11 + self.assertEqual( + instructions[6].instr_args.op_index, 2 # pyre-ignore + ) # aten:view_copy @ idx11