From e9a74911f57dc48c1d21e1d9309866586929fbfc Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 12 Apr 2025 20:19:09 +0800 Subject: [PATCH 1/4] Update exported_program_translator.py --- .../relax/frontend/torch/exported_program_translator.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 875ec3b83ea8..0f97092946bf 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -202,6 +202,13 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: ########## Manipulation ########## + def _narrow(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + start = node.args[2] + length = node.args[3] + return self.block_builder.emit(relax.op.strided_slice(x, [dim], [start], [length])) + def _select(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] @@ -390,6 +397,7 @@ def create_convert_map( "where.self": self._where, # tensor manipulation "argsort.default": self._argsort, + "broadcast_to.default": self._broadcast_to, "cat.default": self._cat, "chunk.default": self._chunk, "clamp.Tensor": self._clamp, @@ -402,6 +410,7 @@ def create_convert_map( "flatten.using_ints": self._flatten, "flip.default": self._flip, "gather.default": self._gather, + "narrow.default": self._narrow, "permute.default": self._permute, "repeat.default": self._repeat, "select.int": self._select, From bf165dc4a2e7ebaae3628d9eb72e0aeed18dd272 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 12 Apr 2025 20:20:31 +0800 Subject: [PATCH 2/4] Update test_frontend_from_exported_program.py --- .../test_frontend_from_exported_program.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 42288cf562cd..a16e45410d0f 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3856,5 +3856,49 @@ def main( verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes) +def test_broadcast_to(): + class BroadcastTo(Module): + def forward(self, x): + return torch.broadcast_to(x, (5, 3)) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((5, 1), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(x, R.shape([5, 3])) + gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,) + R.output(gv) + + return gv + + example_args = (torch.randn(5, 1, dtype=torch.float32),) + verify_model(BroadcastTo(), example_args, {}, Expected) + + +def test_narrow(): + class Narrow(Module): + def forward(self, x): + return torch.narrow(x, 1, 0, 2) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 2), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice(x, (R.prim_value(1),), (R.prim_value(0),), + (R.prim_value(2),), (R.prim_value(1),), + assume_inbound=False) + + gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,) + + R.output(gv) + + return gv + + example_args = (torch.randn(5, 3, dtype=torch.float32),) + verify_model(Narrow(), example_args, {}, Expected) + + if __name__ == "__main__": tvm.testing.main() From f25afa076278ccb89aedbee7c5d4c9b33c45cb0e Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 12 Apr 2025 23:36:24 +0800 Subject: [PATCH 3/4] Update test_frontend_from_exported_program.py --- .../test_frontend_from_exported_program.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index a16e45410d0f..37497cc8f999 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3864,7 +3864,9 @@ def forward(self, x): @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((5, 1), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype="float32")): + def main( + x: R.Tensor((5, 1), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")): with R.dataflow(): lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(x, R.shape([5, 3])) gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,) @@ -3884,14 +3886,19 @@ def forward(self, x): @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 2), dtype="float32")): + def main( + x: R.Tensor((5, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")): with R.dataflow(): - lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice(x, (R.prim_value(1),), (R.prim_value(0),), - (R.prim_value(2),), (R.prim_value(1),), - assume_inbound=False) - + lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice( + x, + (R.prim_value(1),), + (R.prim_value(0),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,) - R.output(gv) return gv From e17018ef8d9e5e7a0812d3383118130abc5ceacf Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 13 Apr 2025 13:38:32 +0800 Subject: [PATCH 4/4] Update test_frontend_from_exported_program.py --- tests/python/relax/test_frontend_from_exported_program.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 37497cc8f999..284544be5079 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3895,7 +3895,6 @@ def main( (R.prim_value(1),), (R.prim_value(0),), (R.prim_value(2),), - (R.prim_value(1),), assume_inbound=False, ) gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)