Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions exir/passes/replace_view_copy_with_view_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if _is_view_copy(node):
# Note: We only replace view_copy nodes that are not output, since
# the output pointer could be modified at runtime (T187925929)
if _is_view_copy(node) and node.next.op != "output":
base, _ = node.args
node.target = _VIEW_OP

Expand All @@ -298,7 +300,9 @@ def ensures(self, graph_module: torch.fx.GraphModule) -> None:
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
assert not _is_view_copy(node)
# Note: We only replace view_copy nodes that are not output, since
# the output pointer could be modified at runtime (T187925929)
assert not (_is_view_copy(node) and node.next.op != "output")
if node.op == "call_function" and node.target == _VIEW_OP:
assert isinstance(node.meta["spec"], _ViewSpec)

Expand All @@ -311,6 +315,8 @@ def requires(self, graph_module: torch.fx.GraphModule) -> None:
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if _is_view_copy(node):
# Note: We only replace view_copy nodes that are not output, since
# the output pointer could be modified at runtime (T187925929)
if _is_view_copy(node) and node.next.op != "output":
base, size = node.args
assert not _is_view_copy(base)
7 changes: 4 additions & 3 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,10 +1630,11 @@ def forward(self, x):
assert gm_res is not None
gm = gm_res.graph_module

# Check before transformation
# Check after transformation
# Note: one view copy is not replaced, because it's the output of the graph
FileCheck().check_count(
"torch.ops.aten.view_copy.default", 0, exactly=True
"torch.ops.aten.view_copy.default", 1, exactly=True
).run(gm.code)
FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run(
FileCheck().check_count("executorch_exir_memory_view", 1, exactly=True).run(
gm.code
)
36 changes: 3 additions & 33 deletions exir/tests/test_remove_view_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,42 +148,12 @@ def test_spec(self) -> None:
self.assertEqual(
node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime
)
elif node.name == "aten_mul_tensor":
# aten_mul_tensor's lifetime is extended through aten_view_copy_default_2 (memory.view) to idx 7
self.assertEqual(node.meta["spec"].lifetime, [4, 7])
elif node.name == "aten_view_copy_default_2":
# aten_view_copy_default_2 is a memory.view of aten_mul_tensor

# assert base is aten_mul_tensor
self.assertEqual(node.args[0].name, "aten_mul_tensor")

# assert base and self are not const, do not have storage,
# but do have mem_id and mem_offset
self.assertFalse(node.args[0].meta["spec"].const)
self.assertTrue(node.args[0].meta["spec"].storage is None)
self.assertTrue(node.args[0].meta["spec"].mem_id is not None)
self.assertTrue(node.args[0].meta["spec"].mem_offset is not None)

self.assertFalse(node.meta["spec"].const)
self.assertTrue(node.meta["spec"].storage is None)
self.assertTrue(node.meta["spec"].mem_id is not None)
self.assertTrue(node.meta["spec"].mem_offset is not None)

# assert self and base mem_id, mem_offset, and lifetime matches
self.assertEqual(
node.meta["spec"].mem_id, node.args[0].meta["spec"].mem_id
)
self.assertEqual(
node.meta["spec"].mem_offset, node.args[0].meta["spec"].mem_offset
)
self.assertEqual(
node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime
)

# Test evalues in execution plan
plan = etpm.executorch_program.execution_plan[0]
self.assertEqual(plan.operators[0].name, "executorch_prim::et_view")
self.assertEqual(plan.operators[1].name, "aten::mul")
self.assertEqual(plan.operators[2].name, "aten::view_copy")

instructions = plan.chains[0].instructions
self.assertEqual(len(instructions), 4)
Expand All @@ -198,5 +168,5 @@ def test_spec(self) -> None:
instructions[2].instr_args.op_index, 1 # pyre-ignore
) # aten:mul @ idx5
self.assertEqual(
instructions[3].instr_args.op_index, 0 # pyre-ignore
) # view @ idx6
instructions[3].instr_args.op_index, 2 # pyre-ignore
) # aten:view_copy @ idx6