diff --git a/backends/qualcomm/_passes/remove_redundancy.py b/backends/qualcomm/_passes/remove_redundancy.py index 825b2584ca7..2b14aed6c7f 100644 --- a/backends/qualcomm/_passes/remove_redundancy.py +++ b/backends/qualcomm/_passes/remove_redundancy.py @@ -11,7 +11,7 @@ class RemoveRedundancy(ExportPass): """ - Trim the 'identity' operators to reduce the unnecessary copy overhead. + Trim certain operators to reduce unnecessary overhead. """ redundant_ops = { @@ -21,6 +21,10 @@ class RemoveRedundancy(ExportPass): torch.ops.aten.alias.default, exir_ops.edge.aten.alias.default, exir_ops.edge.aten.lift_fresh_copy.default, + # remove this target if '_skip_dim_order' is set to False + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + # remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True + exir_ops.edge.aten._to_copy.default, } def __init__(self): @@ -31,6 +35,13 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: if n.target not in self.redundant_ops: continue + # do not remove cast operator + if ( + n.target == exir_ops.edge.aten._to_copy.default + and "memory_format" not in n.kwargs + ): + continue + to_be_remove = n for user_n in list(n.users.keys()): user_n.replace_input_with(n, n.args[0]) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 0ed66329c33..b007400318b 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -325,7 +325,7 @@ def forward(self, x): class Conv2dSequential(torch.nn.Module): - def __init__(self, bias=True): + def __init__(self, bias=True, channel_last=False): super().__init__() self.first = torch.nn.Conv2d( in_channels=1, @@ -341,8 +341,10 @@ def __init__(self, bias=True): padding=1, bias=bias, ) + self.channel_last = channel_last def forward(self, x): + x = x.to(memory_format=torch.channels_last) if self.channel_last else x return self.second(self.first(x)) diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 10917cdd6bf..f81a9dac6ca 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -133,6 +133,16 @@ def test_qnn_backend_conv2d(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d_channel_last(self): + modules = [ + Conv2dSequential(channel_last=True), # noqa: F405 + Conv2dSequential(bias=False, channel_last=True), # noqa: F405 + ] + sample_input = (torch.randn([1, 1, 3, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv_transpose2d(self): modules = [ ConvTranspose2dSingle(), # noqa: F405 @@ -814,6 +824,17 @@ def test_qnn_backend_conv2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d_channel_last(self): + modules = [ + Conv2dSequential(channel_last=True), # noqa: F405 + Conv2dSequential(bias=False, channel_last=True), # noqa: F405 + ] + sample_input = (torch.randn([1, 1, 3, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv_transpose2d(self): modules = [ ConvTranspose2dSingle(), # noqa: F405 diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 590ede74319..33be00ed51d 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -166,7 +166,6 @@ def qnn_capture_config(): def qnn_edge_config() -> exir.EdgeCompileConfig: return exir.EdgeCompileConfig( _check_ir_validity=False, - _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. )