From f2f2932545ef28a66d09324a820ab8cf8a218b2a Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Tue, 29 Jul 2025 12:23:59 -0700 Subject: [PATCH 1/8] Update clone removal transform to be dim order aware; add tests --- backends/transforms/remove_clone_ops.py | 17 ++++++- exir/tests/test_memory_format_ops_pass.py | 56 +++++++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/backends/transforms/remove_clone_ops.py b/backends/transforms/remove_clone_ops.py index 2751dee2816..67913aae02a 100644 --- a/backends/transforms/remove_clone_ops.py +++ b/backends/transforms/remove_clone_ops.py @@ -13,11 +13,24 @@ def remove_clone_ops(graph: torch.fx.Graph) -> torch.fx.Graph: """ - Remove clone op nodes and replace uses with parent node. + Remove clone op nodes that have the same dim_order as their input, and replace their uses with the input node. """ clone_op = exir_ops.edge.aten.clone.default + clone_dim_order_op = exir_ops.edge.dim_order_ops._clone_dim_order.default + for node in graph.nodes: - if node.op == "call_function" and node.target == clone_op: + if node.op != "call_function": + continue + + # Identify clone_dim_order ops with unchanged memory layout. + unchanged_layout_clone = ( + node.target == clone_dim_order_op + and "val" in node.meta + and "val" in node.args[0].meta + and node.meta["val"].dim_order() == node.args[0].meta["val"].dim_order() + ) + + if node.target == clone_op or unchanged_layout_clone: with graph.inserting_after(node): node.replace_all_uses_with(node.args[0]) diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index 84cd0faa485..b084c3d266a 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -12,6 +12,7 @@ import torch import torchvision +from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.exir import EdgeCompileConfig, to_edge from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload @@ -22,12 +23,14 @@ is_contiguous_dim_order, ) from executorch.exir.pass_base import ExportPass, ProxyValue +from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass from executorch.exir.tests.test_memory_format_ops_pass_utils import ( AmbiguousDimOrderError, MemoryFormatOpsPassTestUtils, MemoryFormatTestSet, PropagateToCopyChannalsLastModule, + SimpleCloneChannelsLastModule, SimpleEmptyChannelLastModule, SimpleEmptyContiguoustModule, SimpleToCopyChannelsLastModule, @@ -389,3 +392,56 @@ def test_mobilenet_v3_xnnpack(self) -> None: rtol=1e-3, ), ) + + def test_op_clone_replacement_channels_last_survives(self): + clone_dim_order_op_str = ( + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ) + + model = SimpleCloneChannelsLastModule() + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) + + exported = export(model.eval(), (x,), strict=True) + before_epm = to_edge( + exported, compile_config=EdgeCompileConfig(_skip_dim_order=False) + ) + + updated_epm = before_epm.transform([MemoryFormatOpsPass()]) + updated_epm = updated_epm.transform([RemoveCloneOpsTransform()]) + + FileCheck().check_count(clone_dim_order_op_str, 1, exactly=True).run( + updated_epm.exported_program().graph_module.code + ) + + expected = before_epm.exported_program().module()(x) + actual = updated_epm.exported_program().module()(x) + assert torch.allclose(actual, expected) + 
assert is_channel_last_dim_order(actual) + + def test_op_clone_without_transformation_removed(self): + clone_dim_order_op_str = ( + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ) + + model = SimpleCloneChannelsLastModule() + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last) + + exported = export(model.eval(), (x,), strict=True) + before_epm = to_edge( + exported, compile_config=EdgeCompileConfig(_skip_dim_order=False) + ) + + updated_epm = before_epm.transform([MemoryFormatOpsPass()]) + FileCheck().check_count(clone_dim_order_op_str, 1, exactly=True).run( + updated_epm.exported_program().graph_module.code + ) + + updated_epm = updated_epm.transform([RemoveCloneOpsTransform()]) + FileCheck().check_not(clone_dim_order_op_str).run( + updated_epm.exported_program().graph_module.code + ) + + expected = before_epm.exported_program().module()(x) + actual = updated_epm.exported_program().module()(x) + assert torch.allclose(actual, expected) + assert is_channel_last_dim_order(actual) From ad74bdf572246a211c81795621780d5376ef2476 Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Sat, 2 Aug 2025 21:45:00 -0700 Subject: [PATCH 2/8] Remove explicit MemoryFormatOpsPass transform from clone_dim_order tests --- exir/tests/test_memory_format_ops_pass.py | 107 +++++++++++----------- 1 file changed, 52 insertions(+), 55 deletions(-) diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index b084c3d266a..90a1bdf3c55 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -23,14 +23,12 @@ is_contiguous_dim_order, ) from executorch.exir.pass_base import ExportPass, ProxyValue -from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass from executorch.exir.tests.test_memory_format_ops_pass_utils import ( AmbiguousDimOrderError, MemoryFormatOpsPassTestUtils, MemoryFormatTestSet, PropagateToCopyChannalsLastModule, - SimpleCloneChannelsLastModule, SimpleEmptyChannelLastModule, SimpleEmptyContiguoustModule, SimpleToCopyChannelsLastModule, @@ -327,6 +325,58 @@ def call_operator(self, op, args, kwargs, meta): self.assertTrue(is_contiguous_dim_order(actual)) self.assertTrue(is_contiguous_dim_order(expected)) + def test_op_clone_replacement_channels_last_survives(self): + _clone_dim_order_op_str = ( + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ) + + model = SimpleCloneChannelsLastModule() + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) + + exported = export(model.eval(), (x,), strict=True) + before_epm = to_edge( + exported, compile_config=EdgeCompileConfig(_skip_dim_order=False) + ) + + updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) + + FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run( + updated_epm.exported_program().graph_module.code + ) + + expected = before_epm.exported_program().module()(x) + actual = updated_epm.exported_program().module()(x) + assert torch.allclose(actual, expected) + assert is_channel_last_dim_order(actual) + + def test_op_clone_without_transformation_removed(self): + _clone_dim_order_op_str = ( + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ) + + model = SimpleCloneChannelsLastModule() + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last) + + exported = export(model.eval(), (x,), strict=True) + before_epm = to_edge( + exported, 
compile_config=EdgeCompileConfig(_skip_dim_order=False) + ) + + FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run( + before_epm.exported_program().graph_module.code + ) + + updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) + + FileCheck().check_not(_clone_dim_order_op_str).run( + updated_epm.exported_program().graph_module.code + ) + + expected = before_epm.exported_program().module()(x) + actual = updated_epm.exported_program().module()(x) + assert torch.allclose(actual, expected) + assert is_channel_last_dim_order(actual) + def test_resnet18(self) -> None: model = torchvision.models.resnet18() MemoryFormatOpsPassTestUtils.memory_format_test_runner( @@ -392,56 +442,3 @@ def test_mobilenet_v3_xnnpack(self) -> None: rtol=1e-3, ), ) - - def test_op_clone_replacement_channels_last_survives(self): - clone_dim_order_op_str = ( - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" - ) - - model = SimpleCloneChannelsLastModule() - x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) - - exported = export(model.eval(), (x,), strict=True) - before_epm = to_edge( - exported, compile_config=EdgeCompileConfig(_skip_dim_order=False) - ) - - updated_epm = before_epm.transform([MemoryFormatOpsPass()]) - updated_epm = updated_epm.transform([RemoveCloneOpsTransform()]) - - FileCheck().check_count(clone_dim_order_op_str, 1, exactly=True).run( - updated_epm.exported_program().graph_module.code - ) - - expected = before_epm.exported_program().module()(x) - actual = updated_epm.exported_program().module()(x) - assert torch.allclose(actual, expected) - assert is_channel_last_dim_order(actual) - - def test_op_clone_without_transformation_removed(self): - clone_dim_order_op_str = ( - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" - ) - - model = SimpleCloneChannelsLastModule() - x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last) - - exported = export(model.eval(), (x,), strict=True) - before_epm = to_edge( - exported, compile_config=EdgeCompileConfig(_skip_dim_order=False) - ) - - updated_epm = before_epm.transform([MemoryFormatOpsPass()]) - FileCheck().check_count(clone_dim_order_op_str, 1, exactly=True).run( - updated_epm.exported_program().graph_module.code - ) - - updated_epm = updated_epm.transform([RemoveCloneOpsTransform()]) - FileCheck().check_not(clone_dim_order_op_str).run( - updated_epm.exported_program().graph_module.code - ) - - expected = before_epm.exported_program().module()(x) - actual = updated_epm.exported_program().module()(x) - assert torch.allclose(actual, expected) - assert is_channel_last_dim_order(actual) From ffd15493c0d5ede1df77b753206c10c928f9df9f Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Sat, 16 Aug 2025 00:45:17 -0700 Subject: [PATCH 3/8] Add aten.clone memory_format check in RemoveCloneOpsTransform --- backends/transforms/remove_clone_ops.py | 41 ++++++++--- exir/tests/test_memory_format_ops_pass.py | 90 +++++++++++++---------- 2 files changed, 82 insertions(+), 49 deletions(-) diff --git a/backends/transforms/remove_clone_ops.py b/backends/transforms/remove_clone_ops.py index 67913aae02a..166268bbcc7 100644 --- a/backends/transforms/remove_clone_ops.py +++ b/backends/transforms/remove_clone_ops.py @@ -15,22 +15,11 @@ def remove_clone_ops(graph: torch.fx.Graph) -> torch.fx.Graph: """ Remove clone op nodes that have the same dim_order as their input, and replace their uses with the input node. 
""" - clone_op = exir_ops.edge.aten.clone.default - clone_dim_order_op = exir_ops.edge.dim_order_ops._clone_dim_order.default - for node in graph.nodes: if node.op != "call_function": continue - # Identify clone_dim_order ops with unchanged memory layout. - unchanged_layout_clone = ( - node.target == clone_dim_order_op - and "val" in node.meta - and "val" in node.args[0].meta - and node.meta["val"].dim_order() == node.args[0].meta["val"].dim_order() - ) - - if node.target == clone_op or unchanged_layout_clone: + if is_unchanged_clone(node) or is_unchanged_dim_order_clone(node): with graph.inserting_after(node): node.replace_all_uses_with(node.args[0]) @@ -38,6 +27,34 @@ def remove_clone_ops(graph: torch.fx.Graph) -> torch.fx.Graph: return graph +def is_unchanged_clone(node: torch.fx.Node) -> bool: + """Determine if aten.clone has unchanged memory format.""" + if node.target != exir_ops.edge.aten.clone.default: + return False + + memory_format = node.kwargs.get("memory_format") + if memory_format in (None, torch.preserve_format): + return True + + input_meta = node.args[0].meta + return "val" in input_meta and input_meta["val"].is_contiguous( + memory_format=memory_format + ) + + +def is_unchanged_dim_order_clone(node: torch.fx.Node) -> bool: + """Determine if _clone_dim_order has unchanged dim order.""" + if node.target != exir_ops.edge.dim_order_ops._clone_dim_order.default: + return False + + input_meta = node.args[0].meta + return ( + "val" in node.meta + and "val" in input_meta + and node.meta["val"].dim_order() == input_meta["val"].dim_order() + ) + + class RemoveCloneOpsTransform(ExportPass): def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph_module.graph = remove_clone_ops(graph_module.graph) diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index 90a1bdf3c55..2b73b66981e 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -326,56 +326,72 @@ def call_operator(self, op, args, kwargs, meta): self.assertTrue(is_contiguous_dim_order(expected)) def test_op_clone_replacement_channels_last_survives(self): - _clone_dim_order_op_str = ( - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" - ) + clone_op_cases = [ + # Case testing aten.clone by setting _skip_dim_order to True + (True, "executorch_exir_dialects_edge__ops_aten_clone_default"), + # Case testing _clone_dim_order by setting _skip_dim_order to False + ( + False, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default", + ), + ] - model = SimpleCloneChannelsLastModule() - x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) + for skip_dim_order, clone_op_str in clone_op_cases: + model = SimpleCloneChannelsLastModule() + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) - exported = export(model.eval(), (x,), strict=True) - before_epm = to_edge( - exported, compile_config=EdgeCompileConfig(_skip_dim_order=False) - ) + exported = export(model.eval(), (x,), strict=True) + before_epm = to_edge( + exported, + compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order), + ) - updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) + updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) - FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run( - updated_epm.exported_program().graph_module.code - ) + FileCheck().check_count(clone_op_str, 1, exactly=True).run( + 
updated_epm.exported_program().graph_module.code + ) - expected = before_epm.exported_program().module()(x) - actual = updated_epm.exported_program().module()(x) - assert torch.allclose(actual, expected) - assert is_channel_last_dim_order(actual) + expected = before_epm.exported_program().module()(x) + actual = updated_epm.exported_program().module()(x) + assert torch.allclose(actual, expected) + assert is_channel_last_dim_order(actual) def test_op_clone_without_transformation_removed(self): - _clone_dim_order_op_str = ( - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" - ) + clone_op_cases = [ + # Case testing aten.clone by setting _skip_dim_order to True + (True, "executorch_exir_dialects_edge__ops_aten_clone_default"), + # Case testing _clone_dim_order by setting _skip_dim_order to False + ( + False, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default", + ), + ] - model = SimpleCloneChannelsLastModule() - x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last) + for skip_dim_order, clone_op_str in clone_op_cases: + model = SimpleCloneChannelsLastModule() + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last) - exported = export(model.eval(), (x,), strict=True) - before_epm = to_edge( - exported, compile_config=EdgeCompileConfig(_skip_dim_order=False) - ) + exported = export(model.eval(), (x,), strict=True) + before_epm = to_edge( + exported, + compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order), + ) - FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run( - before_epm.exported_program().graph_module.code - ) + FileCheck().check_count(clone_op_str, 1, exactly=True).run( + before_epm.exported_program().graph_module.code + ) - updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) + updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) - FileCheck().check_not(_clone_dim_order_op_str).run( - updated_epm.exported_program().graph_module.code - ) + FileCheck().check_not(clone_op_str).run( + updated_epm.exported_program().graph_module.code + ) - expected = before_epm.exported_program().module()(x) - actual = updated_epm.exported_program().module()(x) - assert torch.allclose(actual, expected) - assert is_channel_last_dim_order(actual) + expected = before_epm.exported_program().module()(x) + actual = updated_epm.exported_program().module()(x) + assert torch.allclose(actual, expected) + assert is_channel_last_dim_order(actual) def test_resnet18(self) -> None: model = torchvision.models.resnet18() From b8485bc7de2cdf361b93a293293a7327f40c045a Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Thu, 4 Sep 2025 20:47:07 -0700 Subject: [PATCH 4/8] Refactor clone identity check into _is_non_identity_clone --- backends/transforms/remove_clone_ops.py | 53 +++++++++++-------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/backends/transforms/remove_clone_ops.py b/backends/transforms/remove_clone_ops.py index a61c84ad49e..01fe2ee26a4 100644 --- a/backends/transforms/remove_clone_ops.py +++ b/backends/transforms/remove_clone_ops.py @@ -35,10 +35,7 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None: if n.target not in self.clone_ops: continue - # Skip removal of clone ops that modify layout/dim order. 
- if self.aten_clone_is_non_identity( - n - ) or self._clone_dim_order_is_non_identity(n): + if self._is_non_identity_clone(n): continue to_be_removed = n @@ -56,28 +53,26 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: dead_code_elimination_pass(graph_module) return PassResult(graph_module, True) - def aten_clone_is_non_identity(self, node: torch.fx.Node) -> bool: - """Return True if aten.clone has modified memory format.""" - if node.target != exir_ops.edge.aten.clone.default: - return False - - memory_format = node.kwargs.get("memory_format") - if memory_format in (None, torch.preserve_format): - return False - - input_meta = node.args[0].meta - return "val" in input_meta and not input_meta["val"].is_contiguous( - memory_format=memory_format - ) - - def _clone_dim_order_is_non_identity(self, node: torch.fx.Node) -> bool: - """Return True if _clone_dim_order has modified dim order.""" - if node.target != exir_ops.edge.dim_order_ops._clone_dim_order.default: - return False - - input_meta = node.args[0].meta - return ( - "val" in node.meta - and "val" in input_meta - and node.meta["val"].dim_order() != input_meta["val"].dim_order() - ) + def _is_non_identity_clone(self, node: torch.fx.Node) -> bool: + """Return True if clone has modified memory layout or dim order.""" + + # aten.clone: check for memory_format changes + if node.target == exir_ops.edge.aten.clone.default: + memory_format = node.kwargs.get("memory_format") + if memory_format in (None, torch.preserve_format): + return False + input_meta = node.args[0].meta + return "val" in input_meta and not input_meta["val"].is_contiguous( + memory_format=memory_format + ) + + # _clone_dim_order: check for dim_order changes + if node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default: + input_meta = node.args[0].meta + return ( + "val" in node.meta + and "val" in input_meta + and node.meta["val"].dim_order() != input_meta["val"].dim_order() + ) + + return False From e1338984840a4dc0a6391dffe4a8417d06978b7a Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Thu, 4 Sep 2025 20:50:16 -0700 Subject: [PATCH 5/8] Move clone tests from test_memory_format_ops_pass to test_remove_clone_ops --- .../transforms/test/test_remove_clone_ops.py | 69 +++++++++++++++++++ exir/tests/test_memory_format_ops_pass.py | 69 ------------------- 2 files changed, 69 insertions(+), 69 deletions(-) diff --git a/backends/transforms/test/test_remove_clone_ops.py b/backends/transforms/test/test_remove_clone_ops.py index 5d7a1ecd59f..cd7e834c461 100644 --- a/backends/transforms/test/test_remove_clone_ops.py +++ b/backends/transforms/test/test_remove_clone_ops.py @@ -8,13 +8,30 @@ import torch from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform +from executorch.exir import EdgeCompileConfig, to_edge from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dim_order_utils import is_channel_last_dim_order +from executorch.exir.tests.test_memory_format_ops_pass_utils import ( + SimpleCloneChannelsLastModule, +) +from torch.export import export from torch.fx import GraphModule from torch.testing import FileCheck from torch.testing._internal.common_utils import TestCase class TestRemoveCloneOpsTransform(TestCase): + # Clone ops can appear as either aten.clone or _clone_dim_order depending on the _skip_dim_order flag. + # _skip_dim_order=True tests aten.clone + # _skip_dim_order=False tests _clone_dim_order. 
+ CLONE_OP_CASES = [ + (True, "executorch_exir_dialects_edge__ops_aten_clone_default"), + ( + False, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default", + ), + ] + def test_dq_clone_q_linear(self): """ Test RemoveCloneOpsTransform on a graph with d/q -> clone -> q -> linear pattern @@ -123,6 +140,58 @@ def forward(self, x): transformed_gm.code ) + def test_clone_channels_last_survives(self): + """Verify clone ops that modify memory_format are preserved by RemoveCloneOpsTransform.""" + + for skip_dim_order, clone_op_str in self.CLONE_OP_CASES: + model = SimpleCloneChannelsLastModule() + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) + + exported = export(model.eval(), (x,), strict=True) + before_epm = to_edge( + exported, + compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order), + ) + + updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) + + FileCheck().check_count(clone_op_str, 1, exactly=True).run( + updated_epm.exported_program().graph_module.code + ) + + expected = before_epm.exported_program().module()(x) + actual = updated_epm.exported_program().module()(x) + assert torch.allclose(actual, expected) + assert is_channel_last_dim_order(actual) + + def test_clone_identity_removed(self): + """Verify identity clone ops are removed by RemoveCloneOpsTransform.""" + + for skip_dim_order, clone_op_str in self.CLONE_OP_CASES: + model = SimpleCloneChannelsLastModule() + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last) + + exported = export(model.eval(), (x,), strict=True) + before_epm = to_edge( + exported, + compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order), + ) + + FileCheck().check_count(clone_op_str, 1, exactly=True).run( + before_epm.exported_program().graph_module.code + ) + + updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) + + FileCheck().check_not(clone_op_str).run( + updated_epm.exported_program().graph_module.code + ) + + expected = before_epm.exported_program().module()(x) + actual = updated_epm.exported_program().module()(x) + assert torch.allclose(actual, expected) + assert is_channel_last_dim_order(actual) + if __name__ == "__main__": unittest.main() diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index e1be2dfbd43..2384f6123a9 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -12,7 +12,6 @@ import torch import torchvision -from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.exir import EdgeCompileConfig, to_edge from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload @@ -377,74 +376,6 @@ def call_operator(self, op, args, kwargs, meta): self.assertTrue(is_contiguous_dim_order(actual)) self.assertTrue(is_contiguous_dim_order(expected)) - def test_op_clone_replacement_channels_last_survives(self): - clone_op_cases = [ - # Case testing aten.clone by setting _skip_dim_order to True - (True, "executorch_exir_dialects_edge__ops_aten_clone_default"), - # Case testing _clone_dim_order by setting _skip_dim_order to False - ( - False, - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default", - ), - ] - - for skip_dim_order, clone_op_str in clone_op_cases: - model = SimpleCloneChannelsLastModule() - x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) - - exported = export(model.eval(), (x,), strict=True) - before_epm = 
to_edge( - exported, - compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order), - ) - - updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) - - FileCheck().check_count(clone_op_str, 1, exactly=True).run( - updated_epm.exported_program().graph_module.code - ) - - expected = before_epm.exported_program().module()(x) - actual = updated_epm.exported_program().module()(x) - assert torch.allclose(actual, expected) - assert is_channel_last_dim_order(actual) - - def test_op_clone_without_transformation_removed(self): - clone_op_cases = [ - # Case testing aten.clone by setting _skip_dim_order to True - (True, "executorch_exir_dialects_edge__ops_aten_clone_default"), - # Case testing _clone_dim_order by setting _skip_dim_order to False - ( - False, - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default", - ), - ] - - for skip_dim_order, clone_op_str in clone_op_cases: - model = SimpleCloneChannelsLastModule() - x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last) - - exported = export(model.eval(), (x,), strict=True) - before_epm = to_edge( - exported, - compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order), - ) - - FileCheck().check_count(clone_op_str, 1, exactly=True).run( - before_epm.exported_program().graph_module.code - ) - - updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) - - FileCheck().check_not(clone_op_str).run( - updated_epm.exported_program().graph_module.code - ) - - expected = before_epm.exported_program().module()(x) - actual = updated_epm.exported_program().module()(x) - assert torch.allclose(actual, expected) - assert is_channel_last_dim_order(actual) - def test_resnet18(self) -> None: model = torchvision.models.resnet18() MemoryFormatOpsPassTestUtils.memory_format_test_runner( From 17f2e6cdac08e3df1168790a5298bb8c93cef1dd Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Thu, 4 Sep 2025 21:04:31 -0700 Subject: [PATCH 6/8] Change clone test name to test_clone_non_identity_survives --- backends/transforms/test/test_remove_clone_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/transforms/test/test_remove_clone_ops.py b/backends/transforms/test/test_remove_clone_ops.py index cd7e834c461..e5b709caee2 100644 --- a/backends/transforms/test/test_remove_clone_ops.py +++ b/backends/transforms/test/test_remove_clone_ops.py @@ -140,7 +140,7 @@ def forward(self, x): transformed_gm.code ) - def test_clone_channels_last_survives(self): + def test_clone_non_identity_survives(self): """Verify clone ops that modify memory_format are preserved by RemoveCloneOpsTransform.""" for skip_dim_order, clone_op_str in self.CLONE_OP_CASES: From 4b68e11538d5eb5920771e9bb955160a543a4cce Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Fri, 5 Sep 2025 10:55:44 -0700 Subject: [PATCH 7/8] Add test_remove_clone_ops to pytest.ini config --- backends/transforms/test/test_remove_clone_ops.py | 2 +- pytest.ini | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/backends/transforms/test/test_remove_clone_ops.py b/backends/transforms/test/test_remove_clone_ops.py index e5b709caee2..d34c522baaa 100644 --- a/backends/transforms/test/test_remove_clone_ops.py +++ b/backends/transforms/test/test_remove_clone_ops.py @@ -23,7 +23,7 @@ class TestRemoveCloneOpsTransform(TestCase): # Clone ops can appear as either aten.clone or _clone_dim_order depending on the _skip_dim_order flag. 
# _skip_dim_order=True tests aten.clone - # _skip_dim_order=False tests _clone_dim_order. + # _skip_dim_order=False tests _clone_dim_order CLONE_OP_CASES = [ (True, "executorch_exir_dialects_edge__ops_aten_clone_default"), ( diff --git a/pytest.ini b/pytest.ini index aae87f242a7..a16d52a4283 100644 --- a/pytest.ini +++ b/pytest.ini @@ -50,6 +50,8 @@ addopts = --ignore=backends/test backends/test/harness/tests backends/test/suite/tests + # backends/transforms + backends/transforms/test/test_remove_clone_ops.py # backends/xnnpack backends/xnnpack/test/ops --ignore=backends/xnnpack/test/ops/test_bmm.py From ffa31015c531db36fcdf1bfc475e4b91a5145417 Mon Sep 17 00:00:00 2001 From: Zuby Afzal <65686164+keyprocedure@users.noreply.github.com> Date: Mon, 8 Sep 2025 14:59:24 -0700 Subject: [PATCH 8/8] Format pytest.ini --- pytest.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytest.ini b/pytest.ini index 3ab61cf7d2d..3a97b72d504 100644 --- a/pytest.ini +++ b/pytest.ini @@ -115,4 +115,4 @@ addopts = # run the same tests multiple times to determine their # flakiness status. Default to 50 re-runs flake-finder = true -flake-runs = 50 \ No newline at end of file +flake-runs = 50
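
For reference, the removal decision these patches converge on (`_is_non_identity_clone`) boils down to one question: can this clone actually change the layout of its input? The standalone sketch below restates that check with plain torch tensors instead of FX node metadata. It is an illustration only, assuming a recent PyTorch with memory_format support; the helper name `clone_is_identity` is hypothetical and is not an ExecuTorch API.

import torch


def clone_is_identity(inp: torch.Tensor, memory_format=None) -> bool:
    """Return True when cloning `inp` with `memory_format` cannot change its layout."""
    # None / torch.preserve_format keep the input layout, so the clone is a pure copy.
    if memory_format in (None, torch.preserve_format):
        return True
    # Otherwise the clone is an identity only if the input already satisfies
    # the requested memory format.
    return inp.is_contiguous(memory_format=memory_format)


if __name__ == "__main__":
    x = torch.randn(3, 4, 5, 6)  # default contiguous (NCHW-style) layout
    assert clone_is_identity(x)                           # preserve_format -> removable
    assert clone_is_identity(x, torch.contiguous_format)  # no layout change -> removable
    assert not clone_is_identity(x, torch.channels_last)  # layout change -> must be kept

    y = x.to(memory_format=torch.channels_last)
    assert clone_is_identity(y, torch.channels_last)      # already channels_last -> removable

The _clone_dim_order branch in the pass applies the same idea, except that it compares the dim_order() of the node's output and input "val" metadata rather than a memory_format keyword argument.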