From e594d1a1dbc8031ced07edfa038e460c8f5c3c8c Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Fri, 14 Nov 2025 11:31:38 +0800 Subject: [PATCH] Add decomposed operator support for MaxPool --- .../torch/base_fx_graph_translator.py | 48 ++++++ .../torch/exported_program_translator.py | 2 + .../test_frontend_from_exported_program.py | 152 ++++++++++++------ 3 files changed, 155 insertions(+), 47 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 0c8cd4b34fe2..33e8347fb077 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1313,6 +1313,54 @@ def _max_pool3d(self, node: fx.Node) -> relax.Var: ceil_mode = args[5] if len(args) > 5 else False return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _max_pool1d_with_indices(self, node: fx.Node) -> relax.Var: + # max_pool1d_with_indices returns (output, indices) + # We only compute the output and create a placeholder for indices + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + output = self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + # Create a placeholder for indices (empty tensor with same shape as output) + indices = relax.op.zeros_like(output) + return self.block_builder.emit(relax.Tuple([output, indices])) + + def _max_pool2d_with_indices(self, node: fx.Node) -> relax.Var: + # max_pool2d_with_indices returns (output, indices) + # We only compute the output and create a placeholder for indices + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + output = self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + # Create a placeholder for indices (empty tensor with same shape as output) + indices = relax.op.zeros_like(output) + return self.block_builder.emit(relax.Tuple([output, indices])) + + def _max_pool3d_with_indices(self, node: fx.Node) -> relax.Var: + # max_pool3d_with_indices returns (output, indices) + # We only compute the output and create a placeholder for indices + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + output = self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + # Create a placeholder for indices (empty tensor with same shape as output) + indices = relax.op.zeros_like(output) + return self.block_builder.emit(relax.Tuple([output, indices])) + def _pad(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] pad = node.args[1] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index a6da21ada851..5cddf24a89dc 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -986,7 +986,9 @@ def create_convert_map( "gru.input": self._gru, "max_pool1d.default": self._max_pool1d, "max_pool2d.default": self._max_pool2d, + "max_pool2d_with_indices.default": self._max_pool2d_with_indices, "max_pool3d.default": self._max_pool3d, + "max_pool3d_with_indices.default": self._max_pool3d_with_indices, "scaled_dot_product_attention.default": self._scaled_dot_product_attention, "unbind.int": self._unbind, "upsample_bilinear2d.vec": self._upsample_bilinear2d, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 774a50db0e3f..71e400a6a8b1 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3163,16 +3163,24 @@ def main( input_1: R.Tensor((1, 3, 8), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): with R.dataflow(): - lv = R.nn.max_pool1d( - input_1, - pool_size=[2], - strides=[2], - dilation=[1], - padding=[0, 0], - layout="NCW", - out_layout="NCW", + lv: R.Tensor((1, 3, 1, 8), dtype="float32") = R.expand_dims(input_1, axis=[-2]) + lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d( + lv, + pool_size=[1, 2], + strides=[1, 2], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", ) - gv = (lv,) + lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1) + lv3: R.Tuple( + R.Tensor((1, 3, 1, 4), dtype="float32"), + R.Tensor((1, 3, 1, 4), dtype="float32"), + ) = (lv1, lv2) + lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0] + lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2]) + gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,) R.output(gv) return gv @@ -3183,16 +3191,24 @@ def main( input_1: R.Tensor((1, 3, 8), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): with R.dataflow(): - lv = R.nn.max_pool1d( - input_1, - pool_size=[2], - strides=[2], - dilation=[1], - padding=[0, 0], - layout="NCW", - out_layout="NCW", + lv: R.Tensor((1, 3, 1, 8), dtype="float32") = R.expand_dims(input_1, axis=[-2]) + lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d( + lv, + pool_size=[1, 2], + strides=[1, 2], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", ) - gv = (lv,) + lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1) + lv3: R.Tuple( + R.Tensor((1, 3, 1, 4), dtype="float32"), + R.Tensor((1, 3, 1, 4), dtype="float32"), + ) = (lv1, lv2) + lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0] + lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2]) + gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,) R.output(gv) return gv @@ -3203,16 +3219,24 @@ def main( input_1: R.Tensor((1, 3, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): with R.dataflow(): - lv = R.nn.max_pool1d( - input_1, - pool_size=[3], - strides=[2], - dilation=[1], - padding=[0, 0], - layout="NCW", - out_layout="NCW", + lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input_1, axis=[-2]) + lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d( + lv, + pool_size=[1, 3], + strides=[1, 2], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", ) - gv = (lv,) + lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1) + lv3: R.Tuple( + R.Tensor((1, 3, 1, 4), dtype="float32"), + R.Tensor((1, 3, 1, 4), dtype="float32"), + ) = (lv1, lv2) + lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0] + lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2]) + gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,) R.output(gv) return gv @@ -3222,9 +3246,9 @@ def main( example_args3 = (torch.randn(1, 3, 10, dtype=torch.float32),) # Verify the models - verify_model(MaxPool1d(), example_args1, {}, expected1) - verify_model(MaxPool1d_functional(), example_args2, {}, expected2) - verify_model(MaxPool1d2(), example_args3, {}, expected3) + verify_model(MaxPool1d(), example_args1, {}, expected1, run_ep_decomposition=True) + verify_model(MaxPool1d_functional(), example_args2, {}, expected2, run_ep_decomposition=True) + verify_model(MaxPool1d2(), example_args3, {}, expected3, run_ep_decomposition=True) def test_maxpool2d(): @@ -3260,7 +3284,13 @@ def main( layout="NCHW", out_layout="NCHW", ) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.zeros_like(lv) + lv2: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + ) = (lv, lv1) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv2[0] + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -3289,7 +3319,12 @@ def main( layout="NCHW", out_layout="NCHW", ) - gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv,) + lv1: R.Tensor((1, 3, 4, 4), dtype="float32") = R.zeros_like(lv) + lv2: R.Tuple( + R.Tensor((1, 3, 4, 4), dtype="float32"), R.Tensor((1, 3, 4, 4), dtype="float32") + ) = (lv, lv1) + lv3: R.Tensor((1, 3, 4, 4), dtype="float32") = lv2[0] + gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -3318,15 +3353,20 @@ def main( layout="NCHW", out_layout="NCHW", ) - gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv,) + lv1: R.Tensor((1, 3, 6, 6), dtype="float32") = R.zeros_like(lv) + lv2: R.Tuple( + R.Tensor((1, 3, 6, 6), dtype="float32"), R.Tensor((1, 3, 6, 6), dtype="float32") + ) = (lv, lv1) + lv3: R.Tensor((1, 3, 6, 6), dtype="float32") = lv2[0] + gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv3,) R.output(gv) return gv example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(MaxPool2d(), example_args, {}, expected1) - verify_model(MaxPool2d_functional(), example_args, {}, expected1) - verify_model(MaxPool2d2(), example_args, {}, expected2) - verify_model(MaxPool2d3(), example_args, {}, expected3) + verify_model(MaxPool2d(), example_args, {}, expected1, run_ep_decomposition=True) + verify_model(MaxPool2d_functional(), example_args, {}, expected1, run_ep_decomposition=True) + verify_model(MaxPool2d2(), example_args, {}, expected2, run_ep_decomposition=True) + verify_model(MaxPool2d3(), example_args, {}, expected3, run_ep_decomposition=True) def test_maxpool3d(): @@ -3352,7 +3392,7 @@ def main( input_1: R.Tensor((1, 3, 4, 4, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")): with R.dataflow(): - lv = R.nn.max_pool3d( + lv: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.nn.max_pool3d( input_1, pool_size=[1, 1, 1], strides=[1, 1, 1], @@ -3361,7 +3401,13 @@ def main( layout="NCDHW", out_layout="NCDHW", ) - gv = (lv,) + lv1: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.zeros_like(lv) + lv2: R.Tuple( + R.Tensor((1, 3, 4, 4, 4), dtype="float32"), + R.Tensor((1, 3, 4, 4, 4), dtype="float32"), + ) = (lv, lv1) + lv3: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = lv2[0] + gv: R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -3380,7 +3426,7 @@ def main( input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")): with R.dataflow(): - lv = R.nn.max_pool3d( + lv: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = R.nn.max_pool3d( input_1, pool_size=[2, 2, 2], strides=[2, 2, 2], @@ -3389,7 +3435,13 @@ def main( layout="NCDHW", out_layout="NCDHW", ) - gv = (lv,) + lv1: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = R.zeros_like(lv) + lv2: R.Tuple( + R.Tensor((1, 3, 3, 3, 3), dtype="float32"), + R.Tensor((1, 3, 3, 3, 3), dtype="float32"), + ) = (lv, lv1) + lv3: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = lv2[0] + gv: R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -3408,7 +3460,7 @@ def main( input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")): with R.dataflow(): - lv = R.nn.max_pool3d( + lv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.nn.max_pool3d( input_1, pool_size=[3, 3, 3], strides=[2, 2, 2], @@ -3417,7 +3469,13 @@ def main( layout="NCDHW", out_layout="NCDHW", ) - gv = (lv,) + lv1: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.zeros_like(lv) + lv2: R.Tuple( + R.Tensor((1, 3, 5, 5, 5), dtype="float32"), + R.Tensor((1, 3, 5, 5, 5), dtype="float32"), + ) = (lv, lv1) + lv3: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = lv2[0] + gv: R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -3427,10 +3485,10 @@ def main( example_args3 = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),) # Verify the models with expected IR modules - verify_model(MaxPool3d(), example_args1, {}, expected1) - verify_model(MaxPool3d_functional(), example_args1, {}, expected1) - verify_model(MaxPool3d2(), example_args2, {}, expected2) - verify_model(MaxPool3d3(), example_args3, {}, expected3) + verify_model(MaxPool3d(), example_args1, {}, expected1, run_ep_decomposition=True) + verify_model(MaxPool3d_functional(), example_args1, {}, expected1, run_ep_decomposition=True) + verify_model(MaxPool3d2(), example_args2, {}, expected2, run_ep_decomposition=True) + verify_model(MaxPool3d3(), example_args3, {}, expected3, run_ep_decomposition=True) def test_scaled_dot_product_attention():