From 30112548dcabc0d3a912f70d705aa28e5476b77c Mon Sep 17 00:00:00 2001
From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com>
Date: Fri, 14 Nov 2025 17:23:13 +0800
Subject: [PATCH 1/3] Add decomposed operator support for Pad

---
 .../torch/base_fx_graph_translator.py         |  48 ++++-
 .../torch/exported_program_translator.py      |   2 +
 .../test_frontend_from_exported_program.py    | 199 ++++++++++++++++--
 3 files changed, 227 insertions(+), 22 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 33e8347fb077..92749dfdce86 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1379,6 +1379,23 @@ def _pad(self, node: fx.Node) -> relax.Var:
 
         return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value))
 
+    def _constant_pad_nd(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        pad = node.args[1]
+        value = node.args[2] if len(node.args) > 2 else node.kwargs.get("value", 0.0)
+        value = 0.0 if value is None else value
+
+        # PyTorch lists (before, after) pad pairs starting from the last dimension,
+        # so reverse the pairs and left-pad with zeros to cover the leading dimensions.
+        input_ndim = x.struct_info.ndim
+        pad_width = [0] * (input_ndim * 2)
+        pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]
+        reversed_pairs = list(reversed(pad_pairs))
+        flattened = [v for pair in reversed_pairs for v in pair]
+        pad_width[-len(flattened) :] = flattened
+
+        return self.block_builder.emit(relax.op.nn.pad(x, pad_width, "constant", value))
+
     def _pixel_shuffle(self, node: fx.Node) -> relax.Var:
         data = self.env[node.args[0]]
         upscale_factor = node.args[1]
@@ -1665,8 +1682,37 @@ def _index_put(self, node: fx.Node) -> relax.Var:
 
     def _index_tensor(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
+        data = args[0]
         indices = args[1]
-        return self.block_builder.emit(relax.op.index_tensor(args[0], indices))
+
+        # In PyTorch's aten.index.Tensor, None means "select all elements" for that dimension
+        non_none_indices = [(i, idx) for i, idx in enumerate(indices) if idx is not None]
+
+        # Special case: if there's only one non-None index, use take operation
+        if len(non_none_indices) == 1:
+            axis, index_tensor = non_none_indices[0]
+            return self.block_builder.emit(relax.op.take(data, index_tensor, axis=axis))
+
+        # General case: multiple non-None indices require advanced indexing
+        processed_indices = []
+        data_shape = self.shape_of(data)
+
+        for i, idx in enumerate(indices):
+            if idx is None:
+                dim_size = data_shape[i]
+                arange_idx = self.block_builder.emit(
+                    relax.op.arange(
+                        start=relax.PrimValue(0),
+                        end=dim_size,
+                        step=relax.PrimValue(1),
+                        dtype="int64"
+                    )
+                )
+                processed_indices.append(arange_idx)
+            else:
+                processed_indices.append(idx)
+
+        return self.block_builder.emit(relax.op.index_tensor(data, processed_indices))
 
     def _meshgrid(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5cddf24a89dc..409265427342 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -862,6 +862,8 @@ def create_convert_map(
             "_log_softmax.default": self._log_softmax,
             "neg.default": self._unary_op(relax.op.negative),
             "pad.default": self._pad,
+            "constant_pad_nd.default": self._constant_pad_nd,
+            "copy.default": self._copy_,
             "pixel_shuffle.default": self._pixel_shuffle,
             "prelu.default": self._prelu,
             "reciprocal.default": self._reciprocal,
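Editorial note, not part of the patch: the heart of `_constant_pad_nd` above is the pad-width reordering. A minimal standalone sketch of that reordering follows; the helper name `torch_pad_to_relax_pad_width` is invented for illustration.

def torch_pad_to_relax_pad_width(pad, input_ndim):
    """Convert PyTorch's last-dim-first pad list to Relax's dim-order pad_width."""
    # PyTorch: [w_before, w_after, h_before, h_after, ...] (last dimension first)
    # Relax:   [d0_before, d0_after, ..., dn_before, dn_after] (dimension order)
    pad_width = [0] * (input_ndim * 2)
    pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]
    flattened = [v for pair in reversed(pad_pairs) for v in pair]
    pad_width[-len(flattened) :] = flattened
    return pad_width

# Matches the pad_width used by the pad tests below for F.pad(x, [1, 1, 2, 2]).
assert torch_pad_to_relax_pad_width([1, 1, 2, 2], 4) == [0, 0, 0, 0, 2, 2, 1, 1]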
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 71e400a6a8b1..ab4b7af24b54 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2715,13 +2715,25 @@ def main(
             x: R.Tensor((1, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
-                    x,
-                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
-                    pad_mode="reflect",
-                    pad_value=0.0,
+                lv: R.Tensor((14,), dtype="int64") = R.arange(
+                    R.prim_value(-2), R.prim_value(12), R.prim_value(1), dtype="int64"
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                lv1: R.Tensor((14,), dtype="int64") = R.abs(lv)
+                lv2: R.Tensor((14,), dtype="int64") = R.subtract(R.const(9, "int64"), lv1)
+                lv3: R.Tensor((14,), dtype="int64") = R.abs(lv2)
+                lv4: R.Tensor((14,), dtype="int64") = R.subtract(R.const(9, "int64"), lv3)
+                lv5: R.Tensor((1, 3, 14, 10), dtype="float32") = R.take(x, lv4, axis=2, mode="fast")
+                lv6: R.Tensor((12,), dtype="int64") = R.arange(
+                    R.prim_value(-1), R.prim_value(11), R.prim_value(1), dtype="int64"
+                )
+                lv7: R.Tensor((12,), dtype="int64") = R.abs(lv6)
+                lv8: R.Tensor((12,), dtype="int64") = R.subtract(R.const(9, "int64"), lv7)
+                lv9: R.Tensor((12,), dtype="int64") = R.abs(lv8)
+                lv10: R.Tensor((12,), dtype="int64") = R.subtract(R.const(9, "int64"), lv9)
+                lv11: R.Tensor((1, 3, 14, 12), dtype="float32") = R.take(
+                    lv5, lv10, axis=3, mode="fast"
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv11,)
                 R.output(gv)
                 return gv
 
@@ -2732,13 +2744,19 @@ def main(
             x: R.Tensor((1, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
-                    x,
-                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
-                    pad_mode="replicate",
-                    pad_value=0.0,
+                lv: R.Tensor((14,), dtype="int64") = R.arange(
+                    R.prim_value(-2), R.prim_value(12), R.prim_value(1), dtype="int64"
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                lv1: R.Tensor((14,), dtype="int64") = R.clip(lv, R.prim_value(0), R.prim_value(9))
+                lv2: R.Tensor((1, 3, 14, 10), dtype="float32") = R.take(x, lv1, axis=2, mode="fast")
+                lv3: R.Tensor((12,), dtype="int64") = R.arange(
+                    R.prim_value(-1), R.prim_value(11), R.prim_value(1), dtype="int64"
+                )
+                lv4: R.Tensor((12,), dtype="int64") = R.clip(lv3, R.prim_value(0), R.prim_value(9))
+                lv5: R.Tensor((1, 3, 14, 12), dtype="float32") = R.take(
+                    lv2, lv4, axis=3, mode="fast"
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv5,)
                 R.output(gv)
                 return gv
 
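Editorial note, not part of the patch: the two hunks above replace a single R.nn.pad with index arithmetic. This NumPy sketch reproduces the index maps the decomposition emits for a size-10 axis padded by (2, 2); the exact constants mirror the expected IR above.

import numpy as np

# Reflect pad gathers via 9 - |9 - |i||: out-of-range positions fold back inside.
i = np.arange(-2, 12)
reflect_idx = 9 - np.abs(9 - np.abs(i))
print(reflect_idx)    # [2 1 0 1 2 3 4 5 6 7 8 9 8 7]

# Replicate pad simply clamps the same index range to the valid bounds.
replicate_idx = np.clip(i, 0, 9)
print(replicate_idx)  # [0 0 0 1 2 3 4 5 6 7 8 9 9 9]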
@@ -2749,21 +2767,160 @@ def main(
             x: R.Tensor((1, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
+                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.zeros(
+                    R.shape([1, 3, 14, 12]), dtype="float32"
+                )
+                lv1: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice(
+                    lv,
+                    (R.prim_value(3),),
+                    (R.prim_value(1),),
+                    (R.prim_value(11),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                     x,
-                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
-                    pad_mode="circular",
-                    pad_value=0.0,
+                    (R.prim_value(3),),
+                    (R.prim_value(0),),
+                    (R.prim_value(10),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
+                    lv1,
+                    (R.prim_value(2),),
+                    (R.prim_value(2),),
+                    (R.prim_value(12),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
+                    lv2,
+                    (R.prim_value(2),),
+                    (R.prim_value(0),),
+                    (R.prim_value(10),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv5: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice(
+                    lv,
+                    (R.prim_value(3),),
+                    (R.prim_value(1),),
+                    (R.prim_value(11),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv6: R.Tensor((1, 3, 14, 10), dtype="float32") = R.slice_scatter(
+                    lv5, lv4, R.prim_value(2), R.prim_value(12), R.prim_value(1), axis=2
+                )
+                lv7: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
+                    lv, lv6, R.prim_value(1), R.prim_value(11), R.prim_value(1), axis=3
+                )
+                lv8: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
+                    lv7,
+                    (R.prim_value(3),),
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv9: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
+                    lv7,
+                    (R.prim_value(3),),
+                    (R.prim_value(10),),
+                    (R.prim_value(11),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv10: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
+                    lv7, lv9, R.prim_value(0), R.prim_value(1), R.prim_value(1), axis=3
+                )
+                lv11: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
+                    lv10,
+                    (R.prim_value(3),),
+                    (R.prim_value(11),),
+                    (R.prim_value(12),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv12: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
+                    lv10,
+                    (R.prim_value(3),),
+                    (R.prim_value(1),),
+                    (R.prim_value(2),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv13: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
+                    lv10, lv12, R.prim_value(11), R.prim_value(12), R.prim_value(1), axis=3
+                )
+                lv14: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
+                    lv13,
+                    (R.prim_value(2),),
+                    (R.prim_value(0),),
+                    (R.prim_value(2),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv15: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
+                    lv13,
+                    (R.prim_value(2),),
+                    (R.prim_value(10),),
+                    (R.prim_value(12),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv16: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
+                    lv13, lv15, R.prim_value(0), R.prim_value(2), R.prim_value(1), axis=2
+                )
+                lv17: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
+                    lv16,
+                    (R.prim_value(2),),
+                    (R.prim_value(12),),
+                    (R.prim_value(14),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv18: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
+                    lv16,
+                    (R.prim_value(2),),
+                    (R.prim_value(2),),
+                    (R.prim_value(4),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv19: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
+                    lv16, lv18, R.prim_value(12), R.prim_value(14), R.prim_value(1), axis=2
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv19,)
                 R.output(gv)
                 return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant)
-    verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), example_args, {}, expected_reflect)
-    verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), example_args, {}, expected_replicate)
-    verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, {}, expected_circular)
+    verify_model(
+        PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant, run_ep_decomposition=True
+    )
+    verify_model(
+        PadModel(pad=[1, 1, 2, 2], mode="reflect"),
+        example_args,
+        {},
+        expected_reflect,
+        run_ep_decomposition=True,
+    )
+    verify_model(
+        PadModel(pad=[1, 1, 2, 2], mode="replicate"),
+        example_args,
+        {},
+        expected_replicate,
+        run_ep_decomposition=True,
+    )
+    verify_model(
+        PadModel(pad=[1, 1, 2, 2], mode="circular"),
+        example_args,
+        {},
+        expected_circular,
+        run_ep_decomposition=True,
+    )
 
 
 def test_pixel_shuffle():
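Editorial note, not part of the patch: the new run_ep_decomposition=True flag is what exposes the decomposed operators that the expected modules above match. A hedged sketch of how to inspect this with stock PyTorch follows; the exact decomposition output varies by PyTorch version, and this PadModel is a stand-in mirroring the test's model.

import torch
import torch.nn.functional as F

class PadModel(torch.nn.Module):
    def __init__(self, pad, mode="constant"):
        super().__init__()
        self.pad = pad
        self.mode = mode

    def forward(self, x):
        return F.pad(x, self.pad, mode=self.mode)

ep = torch.export.export(PadModel(pad=[1, 1, 2, 2]), (torch.randn(1, 3, 10, 10),))
# After run_decompositions(), constant padding typically appears as
# aten.constant_pad_nd, while reflect/replicate/circular modes lower further
# into the index/slice arithmetic seen in the expected IR above.
print(ep.run_decompositions().graph_module.code)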

From 925f9e4ba9b2e65d1633b0babea55f6ae22081d4 Mon Sep 17 00:00:00 2001
From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com>
Date: Fri, 14 Nov 2025 17:57:12 +0800
Subject: [PATCH 2/3] Fix lint error

---
 python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 92749dfdce86..7b8c51895c98 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1705,7 +1705,7 @@ def _index_tensor(self, node: fx.Node) -> relax.Var:
                         start=relax.PrimValue(0),
                         end=dim_size,
                         step=relax.PrimValue(1),
-                        dtype="int64"
+                        dtype="int64",
                     )
                 )
                 processed_indices.append(arange_idx)

From 5997eacfc262e07603f5f63bf9c1d7c6784dbc58 Mon Sep 17 00:00:00 2001
From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com>
Date: Fri, 14 Nov 2025 23:17:21 +0800
Subject: [PATCH 3/3] Fix test_take

---
 tests/python/relax/test_frontend_from_exported_program.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ab4b7af24b54..a59581a86195 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6106,7 +6106,7 @@ def main(
         ) -> R.Tuple(R.Tensor((3,), dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((5,), dtype="float32") = R.reshape(data, R.shape([5]))
-                lv1: R.Tensor((3,), dtype="float32") = R.index_tensor(lv, (indices,))
+                lv1: R.Tensor((3,), dtype="float32") = R.take(lv, indices, axis=0, mode="fast")
                 gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv1,)
                 R.output(gv)
                 return gv
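Editorial note, not part of the patch: the test_take change follows from the single-index special case added to _index_tensor in patch 1 — aten.index.Tensor with one index tensor is just a take along that axis. A NumPy identity illustrating this (illustrative only):

import numpy as np

data = np.arange(5, dtype=np.float32)
indices = np.array([4, 0, 2])
# Indexing a 1-D tensor with one index tensor equals a take along axis 0,
# which is why the expected IR now emits R.take instead of R.index_tensor.
assert np.array_equal(data[indices], np.take(data, indices, axis=0))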