XNNPack Fails For nn.MaxPool2d #3567

@kinghchan

Description

Hi team,

Here is a minimal example to show the issue with nn.MaxPool2d and XNNPack:

class Wrapper(nn.Module):
    def __init__(self):
        super(Wrapper, self).__init__()
    def forward(self, input):
        m = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        output = m(input)
        return output

When I lower it to the XNNPACK backend like this (with example inputs shown):

wrapper = Wrapper()
wrapper.eval()

input = torch.randn((1,32,23,23))
result = wrapper(input)

pre_autograd_aten_dialect = capture_pre_autograd_graph(wrapper, (input,))
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, (input,))
edge_program: EdgeProgramManager = to_edge(aten_dialect)

##### Lowering to XNNPACK
xnnpack_lowered_module = edge_program.to_backend(XnnpackPartitioner())
exec_prog = xnnpack_lowered_module.to_executorch()

with open("xnnpack_lowered_module.pte", "wb") as f:
    exec_prog.write_to_file(f)

With this Edge Graph output:

class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[1, 32, 23, 23]"):
            aten_max_pool2d_with_indices_default = executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default(arg0_1, [2, 2], [2, 2], [0, 0], [1, 1], True);  arg0_1 = None
            getitem: "f32[1, 32, 12, 12]" = aten_max_pool2d_with_indices_default[0];  aten_max_pool2d_with_indices_default = None
            return (getitem,)

And this XNNPack backend graph output:

class GraphModule(torch.nn.Module):
      def forward(self, arg0_1: "f32[1, 32, 23, 23]"):
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, arg0_1);  lowered_module_0 = arg0_1 = None
            getitem: "f32[1, 32, 12, 12]" = executorch_call_delegate[0];  executorch_call_delegate = None
            return (getitem,)

And when I run it with xnn_executor_runner:

./xnn_executor_runner --model_path xnnpack_lowered_module.pte

This errors out at runtime, with the following logs:

I 00:00:00.001714 executorch:executor_runner.cpp:83] Using method forward
I 00:00:00.001717 executorch:executor_runner.cpp:130] Setting up planned buffer 0, size 86144.
I 00:00:00.015819 executorch:executor_runner.cpp:159] Load duration = 14.085500
I 00:00:00.015840 executorch:executor_runner.cpp:166] Method loaded.
I 00:00:00.015951 executorch:executor_runner.cpp:176] Inputs prepared.
E 00:00:00.046374 executorch:tensor_impl.cpp:151] Attempted to resize a static tensor to a new shape at dimension 2 old_size: 12 new_size: 11
E 00:00:00.046380 executorch:XNNExecutor.cpp:203] Failed to resize output tensor for XNNExecutor
E 00:00:00.046382 executorch:method.cpp:1072] CALL_DELEGATE execute failed at instruction 0: 0x10
I 00:00:00.046385 executorch:executor_runner.cpp:182] Inference latency=0.00ms.
F 00:00:00.046398 executorch:executor_runner.cpp:188] In function main(), assert failed (status == Error::Ok): Execution of method forward failed with status 0x10

I believe the runtime is trying to resize the output tensor from (1, 32, 12, 12) to (1, 32, 11, 11), which fails because the tensor is marked as static.

I confirmed this by setting a breakpoint at runtime/core/portable_type/tensor_impl.cpp:144 and inspecting the tensor shapes.
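For reference, the two shapes line up with the standard pooling output-size arithmetic: with input size 23, kernel 2, stride 2 and no padding, ceil mode gives ceil((23 - 2) / 2) + 1 = 12, while floor mode gives floor((23 - 2) / 2) + 1 = 11. So the Edge graph (which honors ceil_mode=True) plans for 12x12, while the delegate apparently computes the floor-mode size. A minimal sketch of that arithmetic (plain Python, no torch needed; this is the simplified formula and ignores PyTorch's extra ceil-mode clamp on the last window):

```python
import math

def pool_out_size(size, kernel, stride, padding=0, ceil_mode=False):
    # Simplified max-pool output-size formula (dilation=1):
    # out = round((size + 2*padding - kernel) / stride) + 1,
    # where round is ceil or floor depending on ceil_mode.
    rounding = math.ceil if ceil_mode else math.floor
    return rounding((size + 2 * padding - kernel) / stride) + 1

# Input spatial size 23, MaxPool2d(2, stride=2):
print(pool_out_size(23, 2, 2, ceil_mode=True))   # 12 -- what the Edge graph plans for
print(pool_out_size(23, 2, 2, ceil_mode=False))  # 11 -- the size the runtime tries to resize to
```

This matches the error message exactly: the static output tensor was planned at dimension size 12, and the resize to 11 is rejected.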

Is this a bug?

Thanks a lot!

Versions

executorch==0.2.0
torch==2.3.0

Ivan

Labels

actionable: Items in the backlog waiting for an appropriate impl/fix
module: xnnpack: Issues related to xnnpack delegation and the code under backends/xnnpack/
triaged: This issue has been looked at by a team member, triaged, and prioritized into an appropriate module
