From 2b57d57a721cc65966fe4c6e8f8a5a922b08c337 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 22 Aug 2024 21:41:32 +0900 Subject: [PATCH 1/2] add test --- tests/python/relax/test_frontend_from_fx.py | 42 +++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 1a2cc5da6242..6be3e7b23e9d 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3126,6 +3126,48 @@ def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype= verify_model(Reshape(), input_info, {}, expected1) +def test_tile(): + input_info = [([1, 3], "float32")] + + class Tile1(Module): + def forward(self, x): + return x.tile((2,)) + + class Tile2(Module): + def forward(self, x): + return x.tile(4, 2) + + class Tile3(Module): + def forward(self, x): + return torch.tile(x, (4, 2)) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((1, 6), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2]) + gv: R.Tensor((1, 6), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2]) + gv: R.Tensor((4, 6), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tile1(), input_info, {}, expected1) + verify_model(Tile2(), input_info, {}, expected2) + verify_model(Tile3(), input_info, {}, expected2) + + def test_transpose(): input_info = [([1, 2, 3, 4], "float32")] From 9f347e77c7b15ef7558e90e83b5e4edce1bb2125 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 22 Aug 2024 21:41:59 +0900 Subject: [PATCH 2/2] add support for torch.tile --- python/tvm/relax/frontend/torch/fx_translator.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 093f3ae4cf7a..35131d324076 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -612,6 +612,14 @@ def _squeeze(self, node: fx.node.Node) -> relax.Var: dim = None return self.block_builder.emit(relax.op.squeeze(x, dim)) + def _tile(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.tile(args[0], args[1:])) + def _cumsum(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] @@ -1450,6 +1458,7 @@ def create_convert_map(self): "permute": self._permute, "reshape": self._reshape, "split": self._split, + "tile": self._tile, "cumsum": self._cumsum, "chunk": self._chunk, "transpose": self._transpose,