diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index 39d62036a37..012f109f313 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -12,12 +12,10 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch -import torch.nn as nn -from executorch.backends.cadence.aot.compiler import export_to_edge from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass from executorch.backends.cadence.aot.graph_builder import GraphBuilder -from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match +from executorch.backends.cadence.aot.pass_utils import count_node from executorch.backends.cadence.aot.remove_ops import ( RemoveAliasCopyOpPass, RemoveBranchedQuantDequant, @@ -444,89 +442,96 @@ def test_remove_nop_quant_dequant(self): ) def test_remove_nop_aten_linalg_vector_norm(self): - class LinalgVectorNorm(torch.nn.Module): - def forward(self, x: torch.Tensor): - return torch.linalg.vector_norm(x, 2, [0, 1], True) - - model = LinalgVectorNorm() - x = torch.randn([1, 1, 128]) - inputs = (x,) - - graph_module = ( - export_to_edge( - model, - inputs, - ) - .exported_program() - .graph_module + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 1, 128, dtype=torch.float32)) + linalg_vector_norm = builder.call_operator( + op=exir_ops.edge.aten.linalg_vector_norm.default, args=(x, 2, [0, 1], True) ) - - graph_module = none_throws( - RemoveNopLinalgVectorNormOpPass()(graph_module) + builder.output([linalg_vector_norm]) + original = builder.get_graph_module() + graph_after_passes = none_throws( + RemoveNopLinalgVectorNormOpPass()(original) ).graph_module - - # Expect the linalg_vector_norm op to be removed by the pass self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.linalg_vector_norm.default), + count_node( + graph_after_passes, exir_ops.edge.aten.linalg_vector_norm.default + ), 0, ) def test_remove_permutes_around_elemwise_ops_add(self) -> None: - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(8, 8, 1, bias=False) - - def forward(self, x): - x = self.conv(x) - x = torch.permute(x, [0, 3, 1, 2]) - x = torch.add(x, x) - x = torch.permute(x, [0, 2, 3, 1]) - x = self.conv(x) - return x - - inputs = (torch.randn(1, 8, 4, 4),) - graph_module = export_to_edge(M(), inputs).exported_program().graph_module + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 8, 4, 4, dtype=torch.float32)) + permute = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 3, 1, 2]) + ) + add = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, args=(permute, permute) + ) + permute = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(add, [0, 2, 3, 1]) + ) + builder.output([permute]) + original = builder.get_graph_module() p = RemovePermutesAroundElementwiseOps() - graph_module = cast(PassResult, p(graph_module)).graph_module + graph_after_passes = cast(PassResult, p(original)).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0 + ) + def test_keep_permutes_around_elemwise_ops_add(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 8, 4, 4, dtype=torch.float32)) + permute = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [2, 1, 0, 3]) + ) + add = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, args=(permute, permute) + ) + permute = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(add, [0, 1, 3, 2]) + ) + builder.output([permute]) + original = builder.get_graph_module() + p = RemovePermutesAroundElementwiseOps() + graph_after_passes = cast(PassResult, p(original)).graph_module + # Ensure no permutes were removed, since the dimensions don't fit the expected pattern self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0 + count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 2 ) def test_remove_permutes_around_elemwise_ops_add_mean(self) -> None: - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv2d = nn.Conv2d(8, 8, 1) - - def forward(self, x, y): - x = self.conv2d(x) - y = self.conv2d(y) - x = torch.permute(x, [0, 3, 1, 2]) - y = torch.permute(y, [0, 3, 1, 2]) - z = torch.add(x, y) - z = torch.mean(z, dim=[-1, -3], keepdim=True) - z = torch.permute(z, [0, 2, 3, 1]) - z = self.conv2d(z) - return z - - inputs = (torch.randn(1, 8, 4, 4), torch.randn(1, 8, 4, 4)) - graph_module = export_to_edge(M(), inputs).exported_program().graph_module + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 8, 4, 4, dtype=torch.float32)) + y = builder.placeholder("y", torch.randn(1, 8, 4, 4, dtype=torch.float32)) + permute_x = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 3, 1, 2]) + ) + permute_y = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(y, [0, 3, 1, 2]) + ) + add = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, args=(permute_x, permute_y) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(add, [3, 1], True) + ) + permute = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(mean, [0, 2, 3, 1]) + ) + builder.output([permute]) + original = builder.get_graph_module() p = RemovePermutesAroundElementwiseOps() - graph_module = cast(PassResult, p(graph_module)).graph_module - + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0 + count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0 ) - - # verify that mean was updated correctly - mean = [ + mean_op = [ n - for n in graph_module.graph.nodes + for n in graph_after_passes.graph.nodes if n.target == exir_ops.edge.aten.mean.dim ][0] - self.assertEqual(mean.args[1], [2, 3]) + self.assertEqual(mean_op.args[1], [2, 3]) def test_remove_permutes_around_elemwise_ops_slice(self) -> None: builder = GraphBuilder() @@ -544,86 +549,125 @@ def test_remove_permutes_around_elemwise_ops_slice(self) -> None: args=(slice_copy, [0, 3, 1, 2]), ) builder.output([output]) - graph_module = builder.get_graph_module() + original = builder.get_graph_module() p = RemovePermutesAroundElementwiseOps() - graph_module = cast(PassResult, p(graph_module)).graph_module + graph_after_passes = cast(PassResult, p(original)).graph_module # No permutes should remain. self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0 + count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0 ) # Verify that slice dimension was updated correctly. - slices = graph_module.graph.find_nodes( + slices = graph_after_passes.graph.find_nodes( op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor ) self.assertEqual(len(slices), 1) self.assertEqual(slices[0].args[1], 2) def test_remove_permutes_around_elemwise_ops_mul(self) -> None: - class M(torch.nn.Module): - def forward(self, x, y): - x = torch.slice_copy(x, 0, 0, 1) - x = torch.permute(x, [0, 3, 1, 2]) - y = torch.permute(y, [0, 3, 1, 2]) - x = torch.ops.quantized_decomposed.dequantize_per_tensor( - x, 1.5, 0, 0, 255, torch.uint8 - ) - z = x * y - z = torch.ops.quantized_decomposed.quantize_per_tensor( - z, 2.5, 0, 0, 255, torch.uint8 - ) - z = torch.permute(z, [0, 2, 3, 1]) - z = torch.unsqueeze_copy(z, 0) - return z - - inputs = (torch.randn(2, 4, 4, 8), torch.randn(2, 4, 4, 8)) - graph_module = export_to_edge(M(), inputs).exported_program().graph_module - + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 4, 4, 8)) + y = builder.placeholder("y", torch.randn(2, 4, 4, 8)) + sliced_x = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(x, 0, 0, 1), + ) + permuted_x = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(sliced_x, [0, 3, 1, 2]), + ) + permuted_y = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(y, [0, 3, 1, 2]), + ) + dequantized_x = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(permuted_x, 1.5, 0, 0, 255, torch.uint8), + ) + z = builder.call_operator( + op=exir_ops.edge.aten.mul.Tensor, args=(dequantized_x, permuted_y) + ) + quantized_z = builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(z, 2.5, 0, 0, 255, torch.uint8), + ) + permuted_z = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(quantized_z, [0, 2, 3, 1]), + ) + output = builder.call_operator( + op=exir_ops.edge.aten.unsqueeze_copy.default, + args=(permuted_z, 0), + ) + builder.output([output]) + original = builder.get_graph_module() p = RemovePermutesAroundElementwiseOps() - graph_module = cast(PassResult, p(graph_module)).graph_module - + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0 + count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0 ) def test_remove_permutes_around_elemwise_ops_double_permutes(self) -> None: - class M(torch.nn.Module): - def forward(self, x, y): - x = torch.slice_copy(x, 0, 0, 1) - x = torch.permute(x, [0, 3, 1, 2]) - x = torch.permute(x, [0, 3, 1, 2]) - x = torch.ops.quantized_decomposed.dequantize_per_tensor( - x, 1.5, 0, 0, 255, torch.uint8 - ) - y = torch.permute(y, [0, 3, 1, 2]) - y = torch.ops.quantized_decomposed.dequantize_per_tensor( - y, 1.5, 0, 0, 255, torch.uint8 - ) - z = torch.cat((x, y), 1) - z = torch.ops.quantized_decomposed.quantize_per_tensor( - z, 2.5, 0, 0, 255, torch.uint8 - ) - z = torch.permute(z, [0, 2, 3, 1]) - z = torch.permute(z, [0, 2, 3, 1]) - z = torch.unsqueeze_copy(z, 0) - return z - - inputs = (torch.randn(2, 4, 4, 8), torch.randn(1, 8, 4, 4)) - graph_module = export_to_edge(M(), inputs).exported_program().graph_module + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 4, 4, 8)) + y = builder.placeholder("y", torch.randn(1, 8, 4, 4)) + sliced_x = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(x, 0, 0, 1), + ) + permuted_x = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(sliced_x, [0, 3, 1, 2]), + ) + permuted_x = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(permuted_x, [0, 3, 1, 2]), + ) + dequantized_x = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(permuted_x, 1.5, 0, 0, 255, torch.uint8), + ) + permuted_y = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(y, [0, 3, 1, 2]), + ) + dequantized_y = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(permuted_y, 1.5, 0, 0, 255, torch.uint8), + ) + z = builder.call_operator( + op=exir_ops.edge.aten.cat.default, args=((dequantized_x, dequantized_y), 1) + ) + quantized_z = builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(z, 2.5, 0, 0, 255, torch.uint8), + ) + permuted_z = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(quantized_z, [0, 2, 3, 1]), + ) + permuted_z = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(permuted_z, [0, 2, 3, 1]), + ) + output = builder.call_operator( + op=exir_ops.edge.aten.unsqueeze_copy.default, + args=(permuted_z, 0), + ) + builder.output([output]) + original = builder.get_graph_module() p = RemovePermutesAroundElementwiseOps() - graph_module = cast(PassResult, p(graph_module)).graph_module - + graph_after_passes = cast(PassResult, p(original)).graph_module # Expect 2 permutes to remain, one on input x and one on output z self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2 + count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 2 ) - # verify that cat was updated correctly cat = [ n - for n in graph_module.graph.nodes + for n in graph_after_passes.graph.nodes if n.target == exir_ops.edge.aten.cat.default ][0] self.assertEqual(cat.args[1], 3) @@ -692,111 +736,99 @@ def test_remove_permutes_around_elemwise_ops_complicated_case(self) -> None: count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 4 ) - def test_remove_permutes_around_elemwise_ops_noop(self) -> None: - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(8, 8, 1, bias=False) - - def forward(self, x): - x = self.conv(x) - x = torch.permute(x, [2, 1, 0, 3]) - x = torch.add(x, x) - x = torch.permute(x, [0, 1, 3, 2]) - x = self.conv(x) - return x - - inputs = (torch.randn(1, 8, 4, 4),) - graph_module = export_to_edge(M(), inputs).exported_program().graph_module - p = RemovePermutesAroundElementwiseOps() - graph_module = cast(PassResult, p(graph_module)).graph_module - - # Ensure no permutes were removed, since the dimensions don't fit the expected pattern + def test_remove_dequant_on_branch(self): + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 8, 4, 6)) + x = builder.call_operator(op=exir_ops.edge.aten.abs.default, args=(x,)) + x0 = builder.call_operator( + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + args=(x, 1.2, 3, 0, 127, torch.int8), + ) + x1_output = builder.call_operator(op=exir_ops.edge.aten.abs.default, args=(x0,)) + y0 = builder.call_operator( + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + args=(x0, 1.2, 3, 0, 127, torch.int8), + ) + y1_output = builder.call_operator( + op=exir_ops.edge.aten.view.default, + args=(y0, [-1]), + ) + builder.output([x1_output, y1_output]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveBranchedQuantDequant()(original) + ).graph_module self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2 + count_node( + graph_after_passes, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + ), + 1, + ) + self.assertEqual( + count_node( + graph_after_passes, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + ), + 0, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.abs.default), 2 ) - def test_remove_dequant_on_branch(self): - class M(torch.nn.Module): - def forward(self, x): - x = torch.abs(x) - x0 = torch.ops.quantized_decomposed.quantize_per_tensor( - x, 1.2, 3, 0, 127, torch.int8 - ) - x1 = torch.abs(x0) - y0 = torch.ops.quantized_decomposed.dequantize_per_tensor( - x0, 1.2, 3, 0, 127, torch.int8 - ) - y1 = y0.view(-1) - return x1, y1 - - inputs = torch.rand(1, 8, 4, 6) - model = M() - graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module - - graph_module = RemoveBranchedQuantDequant()(graph_module).graph_module - self.assertTrue( - op_counts_match( - graph_module, - expected_op_counts={ - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, - # we expect the pass to remove the dequantize node - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, - exir_ops.edge.aten.abs.default: 2, - }, - ) - ) - - def test_remove_cat_from_slice_copy_all_removal(self) -> None: - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - x1 = torch.cat((x, y), 0) # (2, 4) - return torch.slice_copy(x1, dim=0, start=0, end=1) - - inputs = tuple(torch.randn(2, 4) for _ in range(2)) - graph_module = export_to_edge(M(), inputs).exported_program().graph_module - p = RemoveCatFromSliceCopyPass() - graph_module = cast(PassResult, p(graph_module)).graph_module - - # Ensure both cat nodes were removed - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0) - - def test_remove_cat_from_slice_copy_no_removal(self) -> None: - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - x1 = torch.cat((x, y), 0) # (2, 4) - return torch.slice_copy(x1, dim=0, start=0, end=3) - - inputs = tuple(torch.randn(2, 4) for _ in range(2)) - graph_module = export_to_edge(M(), inputs).exported_program().graph_module - p = RemoveCatFromSliceCopyPass() - graph_module = cast(PassResult, p(graph_module)).graph_module + def test_remove_cat_from_slice_copy(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 4)) + y = builder.placeholder("y", torch.randn(2, 4)) + z = builder.call_operator(op=exir_ops.edge.aten.cat.default, args=((x, y), 0)) + output = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(z, 0, 0, 1), + ) + builder.output([output]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveCatFromSliceCopyPass()(original) + ).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0 + ) - # Ensure both cat nodes were removed - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1) + def test_keep_cat_from_slice_copy(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 4)) + y = builder.placeholder("y", torch.randn(2, 4)) + z = builder.call_operator(op=exir_ops.edge.aten.cat.default, args=((x, y), 0)) + output = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(z, 0, 0, 3), + ) + builder.output([output]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveCatFromSliceCopyPass()(original) + ).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 1 + ) def test_remove_cat_from_slice_copy_zero_range(self) -> None: - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - x1 = torch.cat((x, y), 0) # (2, 4) - return torch.slice_copy(x1, dim=0, start=0, end=0) - - inputs = tuple(torch.randn(2, 4) for _ in range(2)) - graph_module = export_to_edge(M(), inputs).exported_program().graph_module - p = RemoveCatFromSliceCopyPass() - graph_module = cast(PassResult, p(graph_module)).graph_module - - # Ensure both cat nodes were removed - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0) + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 4)) + y = builder.placeholder("y", torch.randn(2, 4)) + z = builder.call_operator(op=exir_ops.edge.aten.cat.default, args=((x, y), 0)) + output = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(z, 0, 0, 0), + ) + builder.output([output]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveCatFromSliceCopyPass()(original) + ).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0 + ) def test_remove_cat_from_slice_copy_second_input(self) -> None: builder = GraphBuilder() @@ -808,7 +840,7 @@ def test_remove_cat_from_slice_copy_second_input(self) -> None: ) slice_copy = builder.call_operator( op=exir_ops.edge.aten.slice_copy.Tensor, - args=(cat, 1, 5, 7, 1), + args=(cat, 1, 5, 7, 1), # dim start end step ) builder.output([slice_copy]) graph_module = builder.get_graph_module()