From 3bc9fddef5f9d3e9ee0deccae2fd3eff8b8db994 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 10 May 2024 13:53:54 -0700 Subject: [PATCH] don't partition max pool with ceil mode (#3578) Summary: XNNPACK doesn't support max pooling with ceil mode, so we should not be partitioning these nodes where ceil mode is True Resolving this issue: https://github.com/pytorch/executorch/issues/3567 Differential Revision: D57228128 --- .../xnnpack/partition/xnnpack_partitioner.py | 5 +++ backends/xnnpack/test/ops/maxpool2d.py | 36 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py index f5b11a631a3..6d483e4ea00 100644 --- a/backends/xnnpack/partition/xnnpack_partitioner.py +++ b/backends/xnnpack/partition/xnnpack_partitioner.py @@ -166,6 +166,8 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes): return True def check_node_has_valid_dtype(self, node): + # max_pool2d_with_indices returns indices which are int64 + # this is supportable within XNNPACK if node.target in {exir_ops.edge.aten.max_pool2d_with_indices.default}: return True @@ -268,13 +270,16 @@ def maxpool2d_with_indices( ) -> bool: """ Only if the first output value is consumed in the graph + and it is not in ceil mode """ users = list(node.users.keys()) + is_ceil_mode = len(node.args) >= 6 and node.args[5] return ( True if len(users) == 1 and users[0].target == operator.getitem and users[0].args[1] == 0 + and not is_ceil_mode else False ) diff --git a/backends/xnnpack/test/ops/maxpool2d.py b/backends/xnnpack/test/ops/maxpool2d.py index e919fc6e776..889c29a5f38 100644 --- a/backends/xnnpack/test/ops/maxpool2d.py +++ b/backends/xnnpack/test/ops/maxpool2d.py @@ -38,6 +38,14 @@ def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1): def forward(self, x): return self.max_pool2d_module(x)[1] + class MaxPool2dUnsupportedCeilMode(torch.nn.Module): + def __init__(self):
super().__init__() + self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + + def forward(self, x): + return self.max_pool2d_module(x) + def _test_maxpool2d(self, inputs): """ Note that the export process generates aten.max_pool2d_with_indices. The remove_getitem_op @@ -99,6 +107,34 @@ def test_fp32_maxpool2d_unsupported(self): ) ) + def test_fp32_maxpool2d_unsupported_ceilmode(self): + """ + MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint). + """ + inputs = (torch.randn(1, 32, 23, 23),) + ( + Tester(self.MaxPool2dUnsupportedCeilMode(), inputs) + .export() + .check_count({"torch.ops.aten.max_pool2d_with_indices.default": 1}) + .to_edge() + .check_count( + { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1 + } + ) + .partition() + # We expect it not to be delegated. + .check_count({"torch.ops.higher_order.executorch_call_delegate": 0}) + .check_count( + { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1 + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + def test_qs8_maxpool2d(self): class MaxPool(torch.nn.Module): def __init__(self, maxpool_params):