Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions exir/passes/replace_view_copy_with_view_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in module.graph.nodes:
# Note: We only replace view_copy nodes that are not output, since
# the output pointer could be modified at runtime (T187925929)
if _is_view_copy(node) and node.next.op != "output":
if _is_view_copy(node) and all(u.op != "output" for u in node.users):
base, _ = node.args
node.target = _VIEW_OP

Expand All @@ -302,7 +302,9 @@ def ensures(self, graph_module: torch.fx.GraphModule) -> None:
for node in module.graph.nodes:
# Note: We only replace view_copy nodes that are not output, since
# the output pointer could be modified at runtime (T187925929)
assert not (_is_view_copy(node) and node.next.op != "output")
assert not (
_is_view_copy(node) and all(u.op != "output" for u in node.users)
)
if node.op == "call_function" and node.target == _VIEW_OP:
assert isinstance(node.meta["spec"], _ViewSpec)

Expand All @@ -317,6 +319,6 @@ def requires(self, graph_module: torch.fx.GraphModule) -> None:
for node in module.graph.nodes:
# Note: We only replace view_copy nodes that are not output, since
# the output pointer could be modified at runtime (T187925929)
if _is_view_copy(node) and node.next.op != "output":
if _is_view_copy(node) and all(u.op != "output" for u in node.users):
base, size = node.args
assert not _is_view_copy(base)
9 changes: 5 additions & 4 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1602,7 +1602,9 @@ def __init__(self):
def forward(self, x):
o1 = torch.ops.aten.view_copy.default(x, [1])
o2 = torch.ops.aten.view_copy.default(self.parameter, [1])
return o1, o2
# view_copys at the end of a function are not replaced, so add
# a computation before the end of the graph.
return torch.ops.aten.add.Tensor(o1, o2)

ep = torch.export.export(
TestViewCopies(),
Expand Down Expand Up @@ -1631,10 +1633,9 @@ def forward(self, x):
gm = gm_res.graph_module

# Check after transformation
# Note: one view copy is not replaced, because it's the output of the graph
FileCheck().check_count(
"torch.ops.aten.view_copy.default", 1, exactly=True
"torch.ops.aten.view_copy.default", 0, exactly=True
).run(gm.code)
FileCheck().check_count("executorch_exir_memory_view", 1, exactly=True).run(
FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run(
gm.code
)
19 changes: 14 additions & 5 deletions exir/tests/test_remove_view_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def forward(self, x):
) # removed, lifetime of mul.Tensor will be extended
v4 = torch.ops.aten.mul.Tensor(v3, self.parameter2)
v5 = v4.view(6, 5) # not removed, output of the graph
return v5
v6 = v4.view(2, 15) # not removed, output of the graph
return v5, v6

def get_example_inputs(self):
return (torch.rand(5, 6),)
Expand Down Expand Up @@ -87,10 +88,15 @@ def test_output_matches(self) -> None:
),
)

out_remove = etpm_remove.exported_program().module()(*example_inputs)
out_no_remove = etpm_no_remove.exported_program().module()(*example_inputs)
out_remove_v5, out_remove_v6 = etpm_remove.exported_program().module()(
*example_inputs
)
out_no_remove_v5, out_no_remove_v6 = etpm_no_remove.exported_program().module()(
*example_inputs
)

self.assertTrue(torch.allclose(out_remove, out_no_remove))
self.assertTrue(torch.allclose(out_remove_v5, out_no_remove_v5))
self.assertTrue(torch.allclose(out_remove_v6, out_no_remove_v6))

def test_spec(self) -> None:
model = TestModel1()
Expand Down Expand Up @@ -196,7 +202,7 @@ def test_spec(self) -> None:
self.assertEqual(plan.operators[2].name, "aten::view_copy")

instructions = plan.chains[0].instructions
self.assertEqual(len(instructions), 6)
self.assertEqual(len(instructions), 7)

self.assertEqual(
instructions[0].instr_args.op_index, 0 # pyre-ignore
Expand All @@ -216,3 +222,6 @@ def test_spec(self) -> None:
self.assertEqual(
instructions[5].instr_args.op_index, 2 # pyre-ignore
) # aten:view_copy @ idx11
self.assertEqual(
instructions[6].instr_args.op_index, 2 # pyre-ignore
) # aten:view_copy @ idx11