From ea60aef97f52410dd55eee1f2dcd709d3b5ce8fb Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Thu, 9 May 2024 12:06:37 -0700 Subject: [PATCH] Disable view_copy elimination for graph outputs (#3565) Summary: If the `view_copy` op is a graph output, leave it as a view_copy for now since the output pointer may be modified at runtime when deploying on device. Right now, the modified pointer would be ignored since the view_copy op will always point to its predecessor memory. cc chrismthompson jcoriell fengwang Differential Revision: D57132664 --- .../replace_view_copy_with_view_pass.py | 12 +++++-- exir/tests/test_passes.py | 7 ++-- exir/tests/test_remove_view_copy.py | 36 ++----------------- 3 files changed, 16 insertions(+), 39 deletions(-) diff --git a/exir/passes/replace_view_copy_with_view_pass.py b/exir/passes/replace_view_copy_with_view_pass.py index a9304f3eec8..8d3a2a32126 100644 --- a/exir/passes/replace_view_copy_with_view_pass.py +++ b/exir/passes/replace_view_copy_with_view_pass.py @@ -273,7 +273,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: if not isinstance(module, torch.fx.GraphModule): continue for node in module.graph.nodes: - if _is_view_copy(node): + # Note: We only replace view_copy nodes that are not output, since + # the output pointer could be modified at runtime (T187925929) + if _is_view_copy(node) and node.next.op != "output": base, _ = node.args node.target = _VIEW_OP @@ -298,7 +300,9 @@ def ensures(self, graph_module: torch.fx.GraphModule) -> None: if not isinstance(module, torch.fx.GraphModule): continue for node in module.graph.nodes: - assert not _is_view_copy(node) + # Note: We only replace view_copy nodes that are not output, since + # the output pointer could be modified at runtime (T187925929) + assert not (_is_view_copy(node) and node.next.op != "output") if node.op == "call_function" and node.target == _VIEW_OP: assert isinstance(node.meta["spec"], _ViewSpec) @@ -311,6 +315,8 @@ def requires(self, graph_module: torch.fx.GraphModule) -> None: if not isinstance(module, torch.fx.GraphModule): continue for node in module.graph.nodes: - if _is_view_copy(node): + # Note: We only replace view_copy nodes that are not output, since + # the output pointer could be modified at runtime (T187925929) + if _is_view_copy(node) and node.next.op != "output": base, size = node.args assert not _is_view_copy(base) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 9c5e4b59adc..0377f70a150 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1630,10 +1630,11 @@ def forward(self, x): assert gm_res is not None gm = gm_res.graph_module - # Check before transformation + # Check after transformation + # Note: one view copy is not replaced, because it's the output of the graph FileCheck().check_count( - "torch.ops.aten.view_copy.default", 0, exactly=True + "torch.ops.aten.view_copy.default", 1, exactly=True ).run(gm.code) - FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run( + FileCheck().check_count("executorch_exir_memory_view", 1, exactly=True).run( gm.code ) diff --git a/exir/tests/test_remove_view_copy.py b/exir/tests/test_remove_view_copy.py index 0c5b61f8d8f..3b7e46593c0 100644 --- a/exir/tests/test_remove_view_copy.py +++ b/exir/tests/test_remove_view_copy.py @@ -148,42 +148,12 @@ def test_spec(self) -> None: self.assertEqual( node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime ) - elif node.name == "aten_mul_tensor": - # aten_mul_tensor's lifetime is extended through aten_view_copy_default_2 (memory.view) to idx 7 - self.assertEqual(node.meta["spec"].lifetime, [4, 7]) - elif node.name == "aten_view_copy_default_2": - # aten_view_copy_default_2 is a memory.view of aten_mul_tensor - - # assert base is aten_mul_tensor - self.assertEqual(node.args[0].name, "aten_mul_tensor") - - # assert base and self are not const, do not have storage, - # but do have mem_id and mem_offset - self.assertFalse(node.args[0].meta["spec"].const) - self.assertTrue(node.args[0].meta["spec"].storage is None) - self.assertTrue(node.args[0].meta["spec"].mem_id is not None) - self.assertTrue(node.args[0].meta["spec"].mem_offset is not None) - - self.assertFalse(node.meta["spec"].const) - self.assertTrue(node.meta["spec"].storage is None) - self.assertTrue(node.meta["spec"].mem_id is not None) - self.assertTrue(node.meta["spec"].mem_offset is not None) - - # assert self and base mem_id, mem_offset, and lifetime matches - self.assertEqual( - node.meta["spec"].mem_id, node.args[0].meta["spec"].mem_id - ) - self.assertEqual( - node.meta["spec"].mem_offset, node.args[0].meta["spec"].mem_offset - ) - self.assertEqual( - node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime - ) # Test evalues in execution plan plan = etpm.executorch_program.execution_plan[0] self.assertEqual(plan.operators[0].name, "executorch_prim::et_view") self.assertEqual(plan.operators[1].name, "aten::mul") + self.assertEqual(plan.operators[2].name, "aten::view_copy") instructions = plan.chains[0].instructions self.assertEqual(len(instructions), 4) @@ -198,5 +168,5 @@ def test_spec(self) -> None: instructions[2].instr_args.op_index, 1 # pyre-ignore ) # aten:mul @ idx5 self.assertEqual( - instructions[3].instr_args.op_index, 0 # pyre-ignore - ) # view @ idx6 + instructions[3].instr_args.op_index, 2 # pyre-ignore + ) # aten:view_copy @ idx6