From e9338589f85303f7a98fd25aaf4ed9b64c4b5376 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:14:56 +0900
Subject: [PATCH 01/20] support cat and concat

---
 .../torch/base_fx_graph_translator.py         |  5 ++
 .../torch/exported_program_translator.py      |  2 +
 .../tvm/relax/frontend/torch/fx_translator.py |  5 --
 .../test_frontend_from_exported_program.py    | 50 +++++++++++++++++++
 4 files changed, 57 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 52784dc8c3cd..f7d66727d986 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -730,6 +730,11 @@ def convert(node: fx.Node):
 
     ########## Manipulation ##########
 
+    def _cat(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
+
     def _reshape(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 64583d750974..b343cc359203 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -249,6 +249,8 @@ def create_convert_map(
             "argmax.default": self._argmax_argmin(relax.op.argmax),
             "argmin.default": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
+            "cat.default": self._cat,
+            "concat.default": self._cat,
             "view.default": self._reshape,
             # other
             "getitem": self._getitem,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index c60c7c3953b4..46ab12312021 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -380,11 +380,6 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var:
 
     ########## Manipulation ##########
 
-    def _cat(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
-        return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
-
     def _chunk(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         chunks = node.args[1]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 7c887d9b9610..ad233004419d 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2734,6 +2734,56 @@ def main(
     verify_model(Argmin2(), example_args, {}, expected_argmin2)
 
 
+def test_cat_concat():
+    class Cat0(Module):
+        def forward(self, x, y):
+            return torch.cat((x, y))
+
+    class Cat1(Module):
+        def forward(self, x, y):
+            return torch.cat((x, y), dim=1)
+
+    class Cat2(Module):
+        def forward(self, x, y):
+            return torch.cat((x, y), 1)
+
+    class Cat3(Module):
+        def forward(self, x, y):
+            return torch.concat((x, y), dim=0)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 3), dtype="float32") = R.concat((inp_0, inp_1), axis=0)
+                gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 6), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 6), dtype="float32") = R.concat((inp_0, inp_1), axis=1)
+                gv: R.Tuple(R.Tensor((2, 6), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32))
+    verify_model(Cat0(), example_args, {}, Expected1)
+    verify_model(Cat1(), example_args, {}, Expected2)
+    verify_model(Cat2(), example_args, {}, Expected2)
+    verify_model(Cat3(), example_args, {}, Expected1)
+
+
 def test_view():
     class View(Module):
         def forward(self, x):

From 73109e0a0bbec1d09e0b6f90605fbb5a2edec22f Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:15:18 +0900
Subject: [PATCH 02/20] support cumsum

---
 .../torch/base_fx_graph_translator.py         | 13 +++++++++++
 .../torch/exported_program_translator.py      |  1 +
 .../tvm/relax/frontend/torch/fx_translator.py | 13 -----------
 .../test_frontend_from_exported_program.py    | 22 +++++++++++++++++++
 4 files changed, 36 insertions(+), 13 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index f7d66727d986..efcb6f715b94 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -735,6 +735,19 @@ def _cat(self, node: fx.Node) -> relax.Var:
         axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
         return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
 
+    def _cumsum(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        else:
+            dtype = None
+        if "out" in node.kwargs:
+            raise ValueError("specifying out for cumsum is not supported yet")
+
+        return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
+
     def _reshape(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index b343cc359203..a0faf41ab73f 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -251,6 +251,7 @@ def create_convert_map(
             # tensor manipulation
             "cat.default": self._cat,
             "concat.default": self._cat,
+            "cumsum.default": self._cumsum,
             "view.default": self._reshape,
             # other
             "getitem": self._getitem,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 46ab12312021..76759829f454 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -386,19 +386,6 @@ def _chunk(self, node: fx.Node) -> relax.Var:
         dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
         return self.block_builder.emit(relax.op.split(x, chunks, dim))
 
-    def _cumsum(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-
-        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
-        if "dtype" in node.kwargs:
-            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
-        else:
-            dtype = None
-        if "out" in node.kwargs:
-            raise ValueError("specifying out for cumsum is not supported yet")
-
-        return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
-
     def _expand(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         sizes = args[1:] if len(args) > 2 else args[1]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ad233004419d..cc70c17f8cf4 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2784,6 +2784,28 @@ def main(
     verify_model(Cat3(), example_args, {}, Expected1)
 
 
+def test_cumsum():
+    class Cumsum(Module):
+        def forward(self, input):
+            return torch.cumsum(input, dim=1, dtype=torch.int32)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="int32") = R.cumsum(input_1, axis=1, dtype="int32")
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(Cumsum(), example_args, {}, expected1)
+
+
 def test_view():
     class View(Module):
         def forward(self, x):

From 8aca98934aa9e97aea97cc5ee70cdffbec899881 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:16:20 +0900
Subject: [PATCH 03/20] support expand

---
 .../torch/base_fx_graph_translator.py         | 11 ++++++++
 .../torch/exported_program_translator.py      |  1 +
 .../tvm/relax/frontend/torch/fx_translator.py | 11 --------
 .../test_frontend_from_exported_program.py    | 27 +++++++++++++++++++
 4 files changed, 39 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index efcb6f715b94..c98c9ec6e40b 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -748,6 +748,17 @@ def _cumsum(self, node: fx.Node) -> relax.Var:
 
         return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
 
+    def _expand(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        sizes = args[1:] if len(args) > 2 else args[1]
+        broadcast_shape, in_shape = [], self.shape_of(args[0])
+        for idx, i in enumerate(sizes):
+            if isinstance(i, int) and i == -1:
+                broadcast_shape.append(in_shape[idx])
+            else:
+                broadcast_shape.append(i)
+        return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
+
     def _reshape(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a0faf41ab73f..6da1e2672a8c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -252,6 +252,7 @@ def create_convert_map(
             "cat.default": self._cat,
             "concat.default": self._cat,
             "cumsum.default": self._cumsum,
+            "expand.default": self._expand,
             "view.default": self._reshape,
             # other
             "getitem": self._getitem,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 76759829f454..3a29b0dd76f7 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -386,17 +386,6 @@ def _chunk(self, node: fx.Node) -> relax.Var:
         dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
         return self.block_builder.emit(relax.op.split(x, chunks, dim))
 
-    def _expand(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        sizes = args[1:] if len(args) > 2 else args[1]
-        broadcast_shape, in_shape = [], self.shape_of(args[0])
-        for idx, i in enumerate(sizes):
-            if isinstance(i, int) and i == -1:
-                broadcast_shape.append(in_shape[idx])
-            else:
-                broadcast_shape.append(i)
-        return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
-
     def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var:
         shape = self.shape_of(x)
         start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index cc70c17f8cf4..a82acd6b8a6c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2806,6 +2806,33 @@ def main(
     verify_model(Cumsum(), example_args, {}, expected1)
 
 
+def test_expand():
+    class Expand1(Module):
+        def forward(self, x):
+            return x.expand(4, 2, 3, 4)
+
+    class Expand2(Module):
+        def forward(self, x):
+            return x.expand(4, -1, -1, 4)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 2, 3, 4), dtype="float32") = R.broadcast_to(x, (4, 2, 3, 4))
+                gv: R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(Expand1(), example_args, {}, expected1)
+    verify_model(Expand2(), example_args, {}, expected1)
+
+
 def test_view():
     class View(Module):
         def forward(self, x):

From d0f416ab432dab8c9809963c178f31a53756fdbe Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:19:58 +0900
Subject: [PATCH 04/20] support permute

---
 .../torch/base_fx_graph_translator.py         |  8 ++++++
 .../torch/exported_program_translator.py      |  1 +
 .../tvm/relax/frontend/torch/fx_translator.py |  8 ------
 .../test_frontend_from_exported_program.py    | 27 +++++++++++++++++++
 4 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index c98c9ec6e40b..4271ff40d0b7 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -759,6 +759,14 @@ def _expand(self, node: fx.Node) -> relax.Var:
                 broadcast_shape.append(i)
         return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
 
+    def _permute(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.permute_dims(x, dims))
+
     def _reshape(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 6da1e2672a8c..090642846439 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -253,6 +253,7 @@ def create_convert_map(
             "concat.default": self._cat,
             "cumsum.default": self._cumsum,
             "expand.default": self._expand,
+            "permute.default": self._permute,
             "view.default": self._reshape,
             # other
             "getitem": self._getitem,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 3a29b0dd76f7..49cb9af63196 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -411,14 +411,6 @@ def _flatten_module(self, node: fx.Node) -> relax.Var:
         end_dim = module.end_dim
         return self._flatten_impl(x, start_dim, end_dim)
 
-    def _permute(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        args = self.retrieve_args(node)
-        x = args[0]
-        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
-        return self.block_builder.emit(relax.op.permute_dims(x, dims))
-
     def _repeat(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index a82acd6b8a6c..99063bfb0e17 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2833,6 +2833,33 @@ def main(
     verify_model(Expand2(), example_args, {}, expected1)
 
 
+def test_permute():
+    class Permute1(Module):
+        def forward(self, x):
+            return x.permute(0, 3, 2, 1)
+
+    class Permute2(Module):
+        def forward(self, x):
+            return torch.permute(x, (0, 3, 2, 1))
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1])
+                gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(Permute1(), example_args, {}, expected1)
+    verify_model(Permute2(), example_args, {}, expected1)
+
+
 def test_view():
     class View(Module):
         def forward(self, x):
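[Editor's note, illustration only; not part of the patch series. With patches 01-04 applied, the new converters can be exercised end to end through the ExportedProgram frontend. `from_exported_program` is assumed here as the entry point that `verify_model` wraps in these tests; the import path is inferred from the test file's name and may differ in your build.]

import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class CatModel(torch.nn.Module):
    def forward(self, x, y):
        # lowered by the new _cat converter to relax.op.concat(..., axis=1)
        return torch.cat((x, y), dim=1)

example_args = (torch.randn(2, 3), torch.randn(2, 3))
mod = from_exported_program(export(CatModel(), example_args))
mod.show()  # should print R.concat along axis 1, as in Expected2 above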
From f258ff02dedc8ae42aaee29a8c1c97e0cf30cb50 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:26:26 +0900
Subject: [PATCH 05/20] support squeeze

---
 .../torch/base_fx_graph_translator.py         |  5 +++
 .../torch/exported_program_translator.py      |  2 +
 .../tvm/relax/frontend/torch/fx_translator.py |  5 ---
 .../test_frontend_from_exported_program.py    | 39 +++++++++++++++++++
 4 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 4271ff40d0b7..1559cda635f3 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -775,6 +775,11 @@ def _reshape(self, node: fx.Node) -> relax.Var:
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
         return self.block_builder.emit(relax.op.reshape(x, dims))
 
+    def _squeeze(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        return self.block_builder.emit(relax.op.squeeze(x, dim))
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 090642846439..caa3022f20dc 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -254,6 +254,8 @@ def create_convert_map(
             "cumsum.default": self._cumsum,
             "expand.default": self._expand,
             "permute.default": self._permute,
+            "squeeze.default": self._squeeze,
+            "squeeze.dim": self._squeeze,
             "view.default": self._reshape,
             # other
             "getitem": self._getitem,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 49cb9af63196..6c3d622960c7 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -442,11 +442,6 @@ def _split(self, node: fx.Node) -> relax.Var:
         n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
         return self.block_builder.emit(relax.op.split(x, n_section, dim))
 
-    def _squeeze(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
-        return self.block_builder.emit(relax.op.squeeze(x, dim))
-
     def _tile(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 99063bfb0e17..ecf894db7dc1 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2860,6 +2860,45 @@ def main(
     verify_model(Permute2(), example_args, {}, expected1)
 
 
+def test_squeeze():
+    class Squeeze1(Module):
+        def forward(self, input):
+            return input.squeeze(1)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 4, 1), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1])
+                gv: R.Tuple(R.Tensor((3, 4, 1), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Squeeze2(Module):
+        def forward(self, input):
+            return input.squeeze()
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None)
+                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),)
+
+    verify_model(Squeeze1(), example_args, {}, Expected1)
+    verify_model(Squeeze2(), example_args, {}, Expected2)
+
+
 def test_view():
     class View(Module):
         def forward(self, x):

From eb04e267ceac0314ab44f1a23516dfc050213e84 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:27:20 +0900
Subject: [PATCH 06/20] support tile

---
 .../torch/base_fx_graph_translator.py         |  8 ++++
 .../torch/exported_program_translator.py      |  1 +
 .../tvm/relax/frontend/torch/fx_translator.py |  8 ----
 .../test_frontend_from_exported_program.py    | 45 +++++++++++++++++++
 4 files changed, 54 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 1559cda635f3..abd766ccc927 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -780,6 +780,14 @@ def _squeeze(self, node: fx.Node) -> relax.Var:
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
         return self.block_builder.emit(relax.op.squeeze(x, dim))
 
+    def _tile(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.tile(x, dims))
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index caa3022f20dc..b27abd008eb7 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -256,6 +256,7 @@ def create_convert_map(
             "permute.default": self._permute,
             "squeeze.default": self._squeeze,
             "squeeze.dim": self._squeeze,
+            "tile.default": self._tile,
             "view.default": self._reshape,
             # other
             "getitem": self._getitem,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 6c3d622960c7..5533552f36f9 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -442,14 +442,6 @@ def _split(self, node: fx.Node) -> relax.Var:
         n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
         return self.block_builder.emit(relax.op.split(x, n_section, dim))
 
-    def _tile(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        args = self.retrieve_args(node)
-        x = args[0]
-        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
-        return self.block_builder.emit(relax.op.tile(x, dims))
-
     def _transpose(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         full_idx = list(range(len(self.shape_of(args[0]))))
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ecf894db7dc1..f9ad0377c2c6 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2899,6 +2899,51 @@ def main(
     verify_model(Squeeze2(), example_args, {}, Expected2)
 
 
+def test_tile():
+    class Tile1(Module):
+        def forward(self, x):
+            return x.tile((2,))
+
+    class Tile2(Module):
+        def forward(self, x):
+            return x.tile(4, 2)
+
+    class Tile3(Module):
+        def forward(self, x):
+            return torch.tile(x, (4, 2))
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2])
+                gv: R.Tuple(R.Tensor((1, 6), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
+                gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, dtype=torch.float32),)
+    verify_model(Tile1(), example_args, {}, expected1)
+    verify_model(Tile2(), example_args, {}, expected2)
+    verify_model(Tile3(), example_args, {}, expected2)
+
+
 def test_view():
     class View(Module):
         def forward(self, x):
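[Editor's note, illustration only; not part of the patch series. The permute and tile converters above share one argument-normalization idiom, `args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]`, so the tuple spelling and the vararg spelling reach the same Relax op. A quick PyTorch-side check of that equivalence:]

import torch

x = torch.randn(1, 3)
# tuple form and vararg form are the same aten op, hence one converter each
assert torch.equal(x.tile(4, 2), torch.tile(x, (4, 2)))
assert torch.equal(x.permute(1, 0), torch.permute(x, (1, 0)))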
From 649d8512c797338298c6b3c7252cdb9ac9e930e9 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:28:38 +0900
Subject: [PATCH 07/20] support transpose

---
 .../torch/base_fx_graph_translator.py         |  6 +++++
 .../torch/exported_program_translator.py      |  1 +
 .../tvm/relax/frontend/torch/fx_translator.py |  6 -----
 .../test_frontend_from_exported_program.py    | 22 +++++++++++++++++++
 4 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index abd766ccc927..51fb1507cd15 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -788,6 +788,12 @@ def _tile(self, node: fx.Node) -> relax.Var:
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
         return self.block_builder.emit(relax.op.tile(x, dims))
 
+    def _transpose(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        full_idx = list(range(len(self.shape_of(args[0]))))
+        full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
+        return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index b27abd008eb7..2eba763b0b85 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -257,6 +257,7 @@ def create_convert_map(
             "squeeze.default": self._squeeze,
             "squeeze.dim": self._squeeze,
             "tile.default": self._tile,
+            "transpose.int": self._transpose,
             "view.default": self._reshape,
             # other
             "getitem": self._getitem,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 5533552f36f9..536fbb4d61e0 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -442,12 +442,6 @@ def _split(self, node: fx.Node) -> relax.Var:
         n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
         return self.block_builder.emit(relax.op.split(x, n_section, dim))
 
-    def _transpose(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        full_idx = list(range(len(self.shape_of(args[0]))))
-        full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
-        return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
-
     ########## Creation ##########
 
     def _arange(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index f9ad0377c2c6..4da9fedd69ed 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2944,6 +2944,28 @@ def main(
     verify_model(Tile3(), example_args, {}, expected2)
 
 
+def test_transpose():
+    class Transpose(Module):
+        def forward(self, x):
+            return x.transpose(1, 3)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1])
+                gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(Transpose(), example_args, {}, expected1)
+
+
 def test_view():
     class View(Module):
         def forward(self, x):

From 76412f9e2cfda58083be575e7ba74cef5876cd36 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:29:17 +0900
Subject: [PATCH 08/20] support unsqueeze

---
 .../torch/exported_program_translator.py      |  3 ++
 .../test_frontend_from_exported_program.py    | 41 +++++++++++++++++++
 2 files changed, 44 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2eba763b0b85..f0588c98fec5 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -258,6 +258,9 @@ def create_convert_map(
             "squeeze.dim": self._squeeze,
             "tile.default": self._tile,
             "transpose.int": self._transpose,
+            "unsqueeze.default": lambda node: self.block_builder.emit(
+                relax.op.expand_dims(self.env[node.args[0]], node.args[1])
+            ),
             "view.default": self._reshape,
             # other
             "getitem": self._getitem,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 4da9fedd69ed..566e831c9131 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2966,6 +2966,47 @@ def main(
     verify_model(Transpose(), example_args, {}, expected1)
 
 
+def test_unsqueeze():
+    class Unsqueeze1(Module):
+        def forward(self, input):
+            return input.unsqueeze(1)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = R.expand_dims(input_1, 1)
+                gv: R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Unsqueeze2(Module):
+        def forward(self, input):
+            return input.unsqueeze(-1)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = R.expand_dims(input_1, -1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    verify_model(Unsqueeze1(), example_args, {}, expected1)
+    verify_model(Unsqueeze2(), example_args, {}, expected2)
+
+
 def test_view():
     class View(Module):
         def forward(self, x):

From b743084bf5bf48da92a79bf4fd179dbd03f4c793 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:30:47 +0900
Subject: [PATCH 09/20] add test for flatten

---
 .../test_frontend_from_exported_program.py    | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 566e831c9131..4820cbc8b1dd 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2833,6 +2833,32 @@ def main(
     verify_model(Expand2(), example_args, {}, expected1)
 
 
+def test_flatten():
+    class Flatten(Module):
+        def __init__(self):
+            super().__init__()
+            self.f = torch.nn.Flatten(2, -1)
+
+        def forward(self, input):
+            return self.f(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 100), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 100), dtype="float32") = R.reshape(input_1, (1, 3, 100))
+                gv: R.Tuple(R.Tensor((1, 3, 100), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Flatten(), example_args, {}, expected1)
+
+
 def test_permute():
     class Permute1(Module):
         def forward(self, x):
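[Editor's note, illustration only; not part of the patch series. The _transpose converter added in patch 07 builds a full permutation and swaps the two requested axes, which is why the test expects axes=[0, 3, 2, 1] for x.transpose(1, 3) on a rank-4 input:]

full_idx = list(range(4))  # [0, 1, 2, 3]
full_idx[1], full_idx[3] = full_idx[3], full_idx[1]
assert full_idx == [0, 3, 2, 1]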
From 615a35016496d99ce6f19f111f1dd256bd56b56f Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:31:29 +0900
Subject: [PATCH 10/20] support repeat

---
 .../torch/base_fx_graph_translator.py         |  8 ++++
 .../torch/exported_program_translator.py      |  1 +
 .../tvm/relax/frontend/torch/fx_translator.py |  8 ----
 .../test_frontend_from_exported_program.py    | 43 +++++++++++++++++++
 4 files changed, 52 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 51fb1507cd15..a7556042dcaf 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -767,6 +767,14 @@ def _permute(self, node: fx.Node) -> relax.Var:
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
         return self.block_builder.emit(relax.op.permute_dims(x, dims))
 
+    def _repeat(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.tile(x, dims))
+
     def _reshape(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index f0588c98fec5..83df02d6a05c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -254,6 +254,7 @@ def create_convert_map(
             "cumsum.default": self._cumsum,
             "expand.default": self._expand,
             "permute.default": self._permute,
+            "repeat.default": self._repeat,
             "squeeze.default": self._squeeze,
             "squeeze.dim": self._squeeze,
             "tile.default": self._tile,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 536fbb4d61e0..27061f534489 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -411,14 +411,6 @@ def _flatten_module(self, node: fx.Node) -> relax.Var:
         end_dim = module.end_dim
         return self._flatten_impl(x, start_dim, end_dim)
 
-    def _repeat(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        args = self.retrieve_args(node)
-        x = args[0]
-        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
-        return self.block_builder.emit(relax.op.tile(x, dims))
-
     def _size(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 4820cbc8b1dd..d9dc6ddece70 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2886,6 +2886,49 @@ def main(
     verify_model(Permute2(), example_args, {}, expected1)
 
 
+def test_repeat():
+    class Tile1(Module):
+        def forward(self, x: torch.Tensor):
+            return x.repeat(2)
+
+    class Tile2(Module):
+        def forward(self, x: torch.Tensor):
+            return x.repeat(4, 2)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(x: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((6,), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2)
+                gv: R.Tuple(R.Tensor((6,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
+                gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(3, dtype=torch.float32),)
+    verify_model(Tile1(), example_args, {}, expected1)
+
+    example_args = (torch.randn(1, 3, dtype=torch.float32),)
+    verify_model(Tile2(), example_args, {}, expected2)
+
+    example_args = (torch.randn(1, 3, dtype=torch.float32),)
+    verify_model(Tile2(), example_args, {}, expected2)
+
+
 def test_squeeze():
     class Squeeze1(Module):
         def forward(self, input):

From e46f6cda2346bdd4ca379ec22ae5842284b4a793 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:32:21 +0900
Subject: [PATCH 11/20] add test for reshape

---
 .../test_frontend_from_exported_program.py    | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index d9dc6ddece70..14c093b5faae 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2929,6 +2929,28 @@ def main(
     verify_model(Tile2(), example_args, {}, expected2)
 
 
+def test_reshape():
+    class Reshape(Module):
+        def forward(self, x):
+            return x.reshape(2, 12)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
+                gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(Reshape(), example_args, {}, expected1)
+
+
 def test_squeeze():
     class Squeeze1(Module):
         def forward(self, input):

From 78938db8618a0e0d0f2d71e8aa257cacafc8e9e7 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:34:55 +0900
Subject: [PATCH 12/20] support select and slice

---
 .../torch/exported_program_translator.py      | 18 ++++
 .../test_frontend_from_exported_program.py    | 83 +++++++++++++++++++
 2 files changed, 101 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 83df02d6a05c..5cadd679a947 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -162,6 +162,22 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var:
         scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None)
         return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor")
 
+    ########## Manipulation ##########
+
+    def _select(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1]
+        index = relax.const(node.args[2], "int64")
+        return self.block_builder.emit(relax.op.take(x, index, dim))
+
+    def _slice(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        axes = [node.args[1]]
+        begin = [node.args[2]]
+        end = [node.args[3]]
+        stride = [node.args[4] if len(node.args) > 4 else 1]
+        return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
+
     def create_convert_map(
         self,
     ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
@@ -255,6 +271,8 @@ def create_convert_map(
             "expand.default": self._expand,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
+            "select.int": self._select,
+            "slice.Tensor": self._slice,
             "squeeze.default": self._squeeze,
             "squeeze.dim": self._squeeze,
             "tile.default": self._tile,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 14c093b5faae..7ce397ce69ea 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2951,6 +2951,89 @@ def main(
     verify_model(Reshape(), example_args, {}, expected1)
 
 
+def test_select_slice():
+    class Slice1(Module):
+        def forward(self, x):
+            return x[0, 1::2, :, :3]
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 10, 3), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((3, 10, 10), dtype="float32") = R.take(x, R.const(0, "int64"), axis=0)
+                lv1: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice(
+                    lv,
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(9223372036854775807),),
+                    (R.prim_value(2),),
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice(
+                    lv1,
+                    (R.prim_value(1),),
+                    (R.prim_value(0),),
+                    (R.prim_value(9223372036854775807),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((1, 10, 3), dtype="float32") = R.strided_slice(
+                    lv2,
+                    (R.prim_value(2),),
+                    (R.prim_value(0),),
+                    (R.prim_value(3),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                gv: R.Tuple(R.Tensor((1, 10, 3), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    class Slice2(Module):
+        def forward(self, x):
+            return x[:, None, None, :, None]
+
+    @I.ir_module
+    class expected2:
+        @R.function
+        def main(
+            x: R.Tensor((8, 16), dtype="float32")
+        ) -> R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice(
+                    x,
+                    (R.prim_value(0),),
+                    (R.prim_value(0),),
+                    (R.prim_value(9223372036854775807),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((8, 1, 16), dtype="float32") = R.expand_dims(lv, axis=[1])
+                lv2: R.Tensor((8, 1, 1, 16), dtype="float32") = R.expand_dims(lv1, axis=[2])
+                lv3: R.Tensor((8, 1, 1, 16), dtype="float32") = R.strided_slice(
+                    lv2,
+                    (R.prim_value(3),),
+                    (R.prim_value(0),),
+                    (R.prim_value(9223372036854775807),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv4: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = R.expand_dims(lv3, axis=[4])
+                gv: R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")) = (lv4,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Slice1(), example_args, {}, expected1)
+
+    example_args = (torch.randn(8, 16, dtype=torch.float32),)
+    verify_model(Slice2(), example_args, {}, expected2)
+
+
 def test_squeeze():
     class Squeeze1(Module):
         def forward(self, input):
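[Editor's note, illustration only; not part of the patch series. Patch 12 decomposes advanced indexing per axis: an integer index becomes select.int (lowered to R.take) and each Python slice becomes one slice.Tensor (lowered to R.strided_slice), which is why expected1 above chains one R.take and three R.strided_slice calls for x[0, 1::2, :, :3]. The same decomposition can be checked in plain PyTorch:]

import torch

x = torch.arange(12).reshape(3, 4)
# x[0, 1::2] is a select on dim 0 followed by a slice on the remaining dim
assert torch.equal(x[0, 1::2], x.select(0, 0)[1::2])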
From 58164c0dc0ee2f0665462b15676088d7805935c5 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:37:04 +0900
Subject: [PATCH 13/20] support arange

---
 .../torch/base_fx_graph_translator.py         | 45 +++++++++++++++++++
 .../torch/exported_program_translator.py      |  2 +
 .../tvm/relax/frontend/torch/fx_translator.py | 43 ------------------
 .../test_frontend_from_exported_program.py    | 21 +++++++++
 4 files changed, 68 insertions(+), 43 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a7556042dcaf..b85effd00185 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -802,6 +802,51 @@ def _transpose(self, node: fx.Node) -> relax.Var:
         full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
         return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
 
+    ########## Creation ##########
+
+    def _arange(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        start_end_step = [None, None, None]
+        if "start" in node.kwargs:
+            start_end_step[0] = node.kwargs["start"]
+        if "end" in node.kwargs:
+            start_end_step[1] = node.kwargs["end"]
+        if "step" in node.kwargs:
+            start_end_step[2] = node.kwargs["step"]
+
+        if len(node.args) == 1:
+            assert start_end_step[1] is None
+            start_end_step[1] = node.args[0]
+        elif len(node.args) == 2:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+        elif len(node.args) == 3:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            assert start_end_step[2] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+            start_end_step[2] = node.args[2]
+
+        if start_end_step[0] is None:
+            start_end_step[0] = 0
+        if start_end_step[2] is None:
+            start_end_step[2] = 1
+
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        elif any([isinstance(x, float) for x in start_end_step]):
+            dtype = self._convert_data_type(torch.get_default_dtype())
+        else:
+            dtype = "int64"
+        start_end_step = [
+            self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step
+        ]
+        return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5cadd679a947..0088998fda82 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -281,6 +281,8 @@ def create_convert_map(
                 relax.op.expand_dims(self.env[node.args[0]], node.args[1])
             ),
             "view.default": self._reshape,
+            # tensor creation
+            "arange.start": self._arange,
             # other
             "getitem": self._getitem,
         }
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 27061f534489..f3fb9a39d6e7 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -436,49 +436,6 @@ def _split(self, node: fx.Node) -> relax.Var:
 
     ########## Creation ##########
 
-    def _arange(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        start_end_step = [None, None, None]
-        if "start" in node.kwargs:
-            start_end_step[0] = node.kwargs["start"]
-        if "end" in node.kwargs:
-            start_end_step[1] = node.kwargs["end"]
-        if "step" in node.kwargs:
-            start_end_step[2] = node.kwargs["step"]
-
-        if len(node.args) == 1:
-            assert start_end_step[1] is None
-            start_end_step[1] = node.args[0]
-        elif len(node.args) == 2:
-            assert start_end_step[0] is None
-            assert start_end_step[1] is None
-            start_end_step[0] = node.args[0]
-            start_end_step[1] = node.args[1]
-        elif len(node.args) == 3:
-            assert start_end_step[0] is None
-            assert start_end_step[1] is None
-            assert start_end_step[2] is None
-            start_end_step[0] = node.args[0]
-            start_end_step[1] = node.args[1]
-            start_end_step[2] = node.args[2]
-
-        if start_end_step[0] is None:
-            start_end_step[0] = 0
-        if start_end_step[2] is None:
-            start_end_step[2] = 1
-
-        if "dtype" in node.kwargs:
-            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
-        elif any([isinstance(x, float) for x in start_end_step]):
-            dtype = self._convert_data_type(torch.get_default_dtype())
-        else:
-            dtype = "int64"
-        start_end_step = [
-            self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step
-        ]
-        return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
-
     def _empty(self, node: fx.Node) -> relax.Var:
         dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
         return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 7ce397ce69ea..91e14fdbb3fa 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3203,6 +3203,27 @@ def main(
     verify_model(View(), example_args, {}, expected1)
 
 
+def test_arange():
+    class Arange(Module):
+        def forward(self, input):
+            return torch.arange(0, 20, dtype=torch.int32)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((20,), dtype="int32")):
+            with R.dataflow():
+                lv: R.Tensor((20,), dtype="int32") = R.arange(0, 20, 1, dtype="int32")
+                gv: R.Tuple(R.Tensor((20,), dtype="int32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(10, 10, dtype=torch.float32),)
+    verify_model(Arange(), example_args, {}, Expected)
+
+
 def test_keep_params():
     class Conv2D1(Module):
         def __init__(self):

From 009dee68df04e4b26dc317d6fa0b7918a9df5744 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:38:29 +0900
Subject: [PATCH 14/20] support empty

---
 .../torch/base_fx_graph_translator.py         |  4 ++++
 .../torch/exported_program_translator.py      |  1 +
 .../tvm/relax/frontend/torch/fx_translator.py |  4 ----
 .../test_frontend_from_exported_program.py    | 23 +++++++++++++++++++
 4 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index b85effd00185..cbb0dac6f374 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -847,6 +847,10 @@ def _arange(self, node: fx.Node) -> relax.Var:
         ]
         return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
 
+    def _empty(self, node: fx.Node) -> relax.Var:
+        dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0088998fda82..5fabc05a76d8 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -283,6 +283,7 @@ def create_convert_map(
             "view.default": self._reshape,
             # tensor creation
            "arange.start": self._arange,
+            "empty.memory_format": self._empty,
             # other
             "getitem": self._getitem,
         }
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index f3fb9a39d6e7..daa38099fcc3 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -436,10 +436,6 @@ def _split(self, node: fx.Node) -> relax.Var:
 
     ########## Creation ##########
 
-    def _empty(self, node: fx.Node) -> relax.Var:
-        dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
-        return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
-
     def _inplace_fill(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 91e14fdbb3fa..e59efcb2f0d6 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3224,6 +3224,29 @@ def main(
     verify_model(Arange(), example_args, {}, Expected)
 
 
+def test_empty():
+    class Empty(Module):
+        def forward(self, input):
+            return torch.empty((10, 10), dtype=torch.float32)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.zeros(
+                    R.shape([10, 10]), dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(10, 10, dtype=torch.float32),)
+    verify_model(Empty(), example_args, {}, Expected)
+
+
 def test_keep_params():
     class Conv2D1(Module):
         def __init__(self):

From 68b206a7934fb8351144e9e24534f48ac8b41e25 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:40:22 +0900
Subject: [PATCH 15/20] support fill

---
 .../torch/base_fx_graph_translator.py         |  7 ++++++
 .../torch/exported_program_translator.py      |  1 +
 .../test_frontend_from_exported_program.py    | 23 +++++++++++++++++++
 3 files changed, 31 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index cbb0dac6f374..1a63abb1d0c5 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -851,6 +851,13 @@ def _empty(self, node: fx.Node) -> relax.Var:
         dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
         return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
 
+    def _fill(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
+        return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5fabc05a76d8..952718909735 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -284,6 +284,7 @@ def create_convert_map(
             # tensor creation
            "arange.start": self._arange,
             "empty.memory_format": self._empty,
+            "fill.Scalar": self._fill,
             # other
             "getitem": self._getitem,
         }
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e59efcb2f0d6..b3a317031690 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3247,6 +3247,29 @@ def main(
     verify_model(Empty(), example_args, {}, Expected)
 
 
+def test_fill():
+    class Fill(Module):
+        def forward(self, input: torch.Tensor):
+            return torch.fill(input, 1.5)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.full(
+                    R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(10, 10, dtype=torch.float32),)
+    verify_model(Fill(), example_args, {}, Expected)
+
+
 def test_keep_params():
     class Conv2D1(Module):
         def __init__(self):
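[Editor's note, illustration only; not part of the patch series. Two semantic choices in patches 13 to 15 are worth flagging: Relax has no uninitialized-tensor op, so _empty lowers torch.empty to relax.op.zeros of the same shape and dtype (safe, since torch.empty's contents are unspecified anyway), and torch.fill maps to relax.op.full with a scalar constant:]

import torch

t = torch.fill(torch.empty(10, 10), 1.5)  # the frontend emits R.full((10, 10), 1.5)
assert bool((t == 1.5).all())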
index 952718909735..d80dab65490f 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -285,6 +285,7 @@ def create_convert_map( "arange.start": self._arange, "empty.memory_format": self._empty, "fill.Scalar": self._fill, + "new_ones.default": self._new_ones, # other "getitem": self._getitem, } diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index daa38099fcc3..6debe9c336bc 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -485,21 +485,6 @@ def _masked_fill(self, node: fx.Node) -> relax.Var: values = self.block_builder.emit(relax.op.full_like(x, rx_value)) return self.block_builder.emit(relax.op.where(mask, values, x)) - def _new_ones(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - self_var = args[0] - size = args[1] if isinstance(args[1], (list, tuple)) else args[1:] - if not isinstance(size, (list, tuple)): - size = (size,) - size = relax.ShapeExpr(size) - return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, self_var.struct_info.dtype), - self_var.struct_info.dtype, - ) - ) - def _ones(self, node: fx.Node) -> relax.Var: import torch diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index b3a317031690..3fc19db3881b 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3270,6 +3270,30 @@ def main( verify_model(Fill(), example_args, {}, Expected) +def test_new_ones(): + class NewOnes(Module): + def forward(self, x): + return x.new_ones(1, 2, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3), dtype="float32") = R.full( + (1, 2, 3), R.const(1, "float32"), dtype="float32" + ) + gv: R.Tuple(R.Tensor((1, 2, 3), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, dtype=torch.float32),) + verify_model(NewOnes(), example_args, {}, expected1) + + def test_keep_params(): class Conv2D1(Module): def __init__(self): From e82d5730dfd362a1c46f4d5ea5ebaeab20069a60 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 30 Sep 2024 12:48:52 +0900 Subject: [PATCH 17/20] support _to_copy --- .../torch/base_fx_graph_translator.py | 13 +++ .../torch/exported_program_translator.py | 1 + .../test_frontend_from_exported_program.py | 94 +++++++++++++++++++ 3 files changed, 108 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index dd259239ea63..1e13f957b6d0 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -804,6 +804,19 @@ def _transpose(self, node: fx.Node) -> relax.Var: ########## Creation ########## + def _to_copy(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + if len(node.args) == 2: + if isinstance(node.args[1], torch.dtype): + dtype = self._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + elif "dtype" in node.kwargs: + dtype = 
 .../torch/base_fx_graph_translator.py         | 15 ++++++++++++
 .../torch/exported_program_translator.py      |  1 +
 .../tvm/relax/frontend/torch/fx_translator.py | 15 ------------
 .../test_frontend_from_exported_program.py    | 24 +++++++++++++++++++
 4 files changed, 40 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 1a63abb1d0c5..dd259239ea63 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -858,6 +858,21 @@ def _fill(self, node: fx.Node) -> relax.Var:
         value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
         return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
 
+    def _new_ones(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        self_var = args[0]
+        size = args[1] if isinstance(args[1], (list, tuple)) else args[1:]
+        if not isinstance(size, (list, tuple)):
+            size = (size,)
+        size = relax.ShapeExpr(size)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                relax.const(1, self_var.struct_info.dtype),
+                self_var.struct_info.dtype,
+            )
+        )
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 952718909735..d80dab65490f 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -285,6 +285,7 @@ def create_convert_map(
             "arange.start": self._arange,
             "empty.memory_format": self._empty,
             "fill.Scalar": self._fill,
+            "new_ones.default": self._new_ones,
             # other
             "getitem": self._getitem,
         }
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index daa38099fcc3..6debe9c336bc 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -485,21 +485,6 @@ def _masked_fill(self, node: fx.Node) -> relax.Var:
         values = self.block_builder.emit(relax.op.full_like(x, rx_value))
         return self.block_builder.emit(relax.op.where(mask, values, x))
 
-    def _new_ones(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        self_var = args[0]
-        size = args[1] if isinstance(args[1], (list, tuple)) else args[1:]
-        if not isinstance(size, (list, tuple)):
-            size = (size,)
-        size = relax.ShapeExpr(size)
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                relax.const(1, self_var.struct_info.dtype),
-                self_var.struct_info.dtype,
-            )
-        )
-
     def _ones(self, node: fx.Node) -> relax.Var:
         import torch
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index b3a317031690..3fc19db3881b 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3270,6 +3270,30 @@ def main(
     verify_model(Fill(), example_args, {}, Expected)
 
 
+def test_new_ones():
+    class NewOnes(Module):
+        def forward(self, x):
+            return x.new_ones(1, 2, 3)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3), dtype="float32") = R.full(
+                    (1, 2, 3), R.const(1, "float32"), dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((1, 2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, dtype=torch.float32),)
+    verify_model(NewOnes(), example_args, {}, expected1)
+
+
 def test_keep_params():
     class Conv2D1(Module):
         def __init__(self):

From e82d5730dfd362a1c46f4d5ea5ebaeab20069a60 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:48:52 +0900
Subject: [PATCH 17/20] support _to_copy

---
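Notes: `aten::_to_copy` is what `Tensor.float()`, `.half()`, `.type(...)`, and
`.to(...)` decompose to under torch.export. Only the dtype argument matters to
the Relax graph: a dtype change becomes `relax.op.astype`, while a device-only
copy such as `.to("cpu")` falls through and returns the input unchanged. One
way to see the decomposition (illustrative sketch; the exact graph text varies
by PyTorch version):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return x.half()

    ep = torch.export.export(M(), (torch.randn(2, 2),))
    # the printed graph contains an aten._to_copy call with dtype=torch.float16
    print(ep.graph_module.graph)
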
 .../torch/base_fx_graph_translator.py         | 13 +++
 .../torch/exported_program_translator.py      |  1 +
 .../test_frontend_from_exported_program.py    | 94 +++++++++++++++++++
 3 files changed, 108 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index dd259239ea63..1e13f957b6d0 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -804,6 +804,19 @@ def _transpose(self, node: fx.Node) -> relax.Var:
 
     ########## Creation ##########
 
+    def _to_copy(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        x = self.env[node.args[0]]
+        if len(node.args) == 2:
+            if isinstance(node.args[1], torch.dtype):
+                dtype = self._convert_data_type(node.args[1], self.env)
+                return self.block_builder.emit(relax.op.astype(x, dtype))
+        elif "dtype" in node.kwargs:
+            dtype = self._convert_data_type(node.kwargs["dtype"], self.env)
+            return self.block_builder.emit(relax.op.astype(x, dtype))
+        return x
+
     def _arange(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index d80dab65490f..b93b2c3758cd 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -282,6 +282,7 @@ def create_convert_map(
             ),
             "view.default": self._reshape,
             # tensor creation
+            "_to_copy.default": self._to_copy,
             "arange.start": self._arange,
             "empty.memory_format": self._empty,
             "fill.Scalar": self._fill,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 3fc19db3881b..391f54d0687c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3294,6 +3294,100 @@ def main(
     verify_model(NewOnes(), example_args, {}, expected1)
 
 
+def test_to_copy():
+    # float
+    class ToFloat(Module):
+        def forward(self, x):
+            return x.float()
+
+    @tvm.script.ir_module
+    class expected_float:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32")
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # half
+    class ToHalf(Module):
+        def forward(self, x):
+            return x.half()
+
+    @tvm.script.ir_module
+    class expected_half:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x, dtype="float16")
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # type
+    class Type(Module):
+        def forward(self, x):
+            return x.type(torch.float32)
+
+    @tvm.script.ir_module
+    class expected_type:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (x,)
+                R.output(gv)
+            return gv
+
+    class To1(Module):
+        def forward(self, input):
+            return input.to(torch.float16)
+
+    @I.ir_module
+    class expected_to1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")):
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(inp_0, dtype="float16")
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class To2(Module):
+        def forward(self, input):
+            return input.to("cpu")
+
+    @I.ir_module
+    class expected_to2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(inp_0, dtype="float32")
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(ToFloat(), example_args, {}, expected_float)
+    verify_model(ToHalf(), example_args, {}, expected_half)
+    verify_model(Type(), example_args, {}, expected_type)
+    verify_model(To1(), example_args, {}, expected_to1)
+    verify_model(To2(), example_args, {}, expected_to2)
+
+
 def test_keep_params():
     class Conv2D1(Module):
         def __init__(self):

From e97c45949855b586177035a2b6fd780ce2975fb8 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:51:58 +0900
Subject: [PATCH 18/20] support split

---
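Notes: `torch.split` takes chunk *sizes*, while `relax.op.split` takes either a
section count or a list of split *indices*, so the `_split` converter below
turns sizes into cumulative indices (dropping the last size) and, in the scalar
case, rounds up so a trailing partial chunk is kept. The size-to-index
arithmetic in isolation (illustrative only):

    split_size = [2, 3, 5]  # torch-style sizes along the axis
    n_section = []
    for s in split_size[:-1]:
        cum_sum = 0 if not n_section else n_section[-1]
        n_section.append(s + cum_sum)
    assert n_section == [2, 5]  # relax-style split indices for an axis of length 10
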
 .../torch/base_fx_graph_translator.py         | 13 +++++++
 .../torch/exported_program_translator.py      |  1 +
 .../tvm/relax/frontend/torch/fx_translator.py | 13 -------
 .../test_frontend_from_exported_program.py    | 37 +++++++++++++++++++
 4 files changed, 51 insertions(+), 13 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 1e13f957b6d0..322ee04e0c20 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -783,6 +783,19 @@ def _reshape(self, node: fx.Node) -> relax.Var:
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
         return self.block_builder.emit(relax.op.reshape(x, dims))
 
+    def _split(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        split_size = node.args[1]
+        dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
+        if isinstance(split_size, (list, tuple)):
+            n_section = []
+            for s in split_size[:-1]:
+                cum_sum = 0 if not n_section else n_section[-1]
+                n_section.append(s + cum_sum)
+        else:
+            n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
+        return self.block_builder.emit(relax.op.split(x, n_section, dim))
+
     def _squeeze(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index b93b2c3758cd..e92dec6127cf 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -273,6 +273,7 @@ def create_convert_map(
             "repeat.default": self._repeat,
             "select.int": self._select,
             "slice.Tensor": self._slice,
+            "split.Tensor": self._split,
             "squeeze.default": self._squeeze,
             "squeeze.dim": self._squeeze,
             "tile.default": self._tile,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 6debe9c336bc..9fbc95fa7c00 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -421,19 +421,6 @@ def _size(self, node: fx.Node) -> relax.Expr:
             idx = node.args[1]
         return self.shape_of(x)[idx].value
 
-    def _split(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        split_size = node.args[1]
-        dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
-        if isinstance(split_size, (list, tuple)):
-            n_section = []
-            for s in split_size[:-1]:
-                cum_sum = 0 if not n_section else n_section[-1]
-                n_section.append(s + cum_sum)
-        else:
-            n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
-        return self.block_builder.emit(relax.op.split(x, n_section, dim))
-
     ########## Creation ##########
 
     def _inplace_fill(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 391f54d0687c..ad0a476d53ab 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3034,6 +3034,43 @@ def main(
     verify_model(Slice2(), example_args, {}, expected2)
 
 
+def test_split():
+    class Chunk(Module):
+        def forward(self, input):
+            return torch.chunk(input, 3, dim=1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=3, axis=1)
+                lv1: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[0]
+                lv2: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[1]
+                lv3: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[2]
+                gv: R.Tuple(
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                ) = (lv1, lv2, lv3)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Chunk(), example_args, {}, Expected)
+
+
 def test_squeeze():
     class Squeeze1(Module):
         def forward(self, input):

From 975e0d409387b7e9f6892c486fd8f8e73b8f677d Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 12:53:49 +0900
Subject: [PATCH 19/20] add test for unbind

---
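Notes: no converter change is needed for unbind — under torch.export,
`torch.unbind` decomposes into `split.Tensor` plus per-slice `squeeze.dim`,
which is why the expected modules below show an `R.split` into width-1 slices
(plus an empty tail piece that is discarded) followed by `R.squeeze` on the
unbound axis. The PyTorch semantics being covered (illustrative only):

    import torch

    x = torch.randn(3, 3, 10, 10)
    outs = torch.unbind(x, dim=1)  # tuple of 3 tensors
    assert all(o.shape == (3, 10, 10) for o in outs)
    assert torch.equal(outs[0], x[:, 0])  # unbind == index + squeeze
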
 .../test_frontend_from_exported_program.py    | 92 +++++++++++++++++++
 1 file changed, 92 insertions(+)

diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ad0a476d53ab..a9559072043f 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3067,9 +3067,101 @@ def main(
                 R.output(gv)
             return gv
 
+    class Unbind1(Module):
+        def forward(self, data):
+            return torch.unbind(data)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((0, 3, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0)
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
+                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0])
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0])
+                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2]
+                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0])
+                lv7: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv2, lv4, lv6)
+                lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
+                lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
+                lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+                gv: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv8, lv9, lv10)
+                R.output(gv)
+            return gv
+
+    class Unbind2(Module):
+        def forward(self, data):
+            return torch.unbind(data, dim=1)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 0, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1)
+                lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
+                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1])
+                lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1])
+                lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2]
+                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1])
+                lv7: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv2, lv4, lv6)
+                lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
+                lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
+                lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+                gv: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv8, lv9, lv10)
+                R.output(gv)
+            return gv
+
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
     verify_model(Chunk(), example_args, {}, Expected)
 
+    example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Unbind1(), example_args, {}, expected1)
+    verify_model(Unbind2(), example_args, {}, expected2)
+
 
 def test_squeeze():
     class Squeeze1(Module):

From 4dd7623511fb97281aa93b943b87cd85160dfd30 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 30 Sep 2024 13:10:00 +0900
Subject: [PATCH 20/20] support clone

---
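Notes: Relax values are immutable SSA vars, so `aten::clone` needs no copy at
all — the one-line lambda below simply returns the already-translated input
var, and the expected module is accordingly a pass-through. A quick way to
confirm that `torch.clone` survives export as a `clone.default` node
(illustrative sketch; graph text varies by PyTorch version):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.clone(x)

    ep = torch.export.export(M(), (torch.randn(10, 10),))
    print(ep.graph_module.graph)  # contains a single aten.clone.default call
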
 .../torch/exported_program_translator.py      |  1 +
 .../test_frontend_from_exported_program.py    | 20 +++++++++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index e92dec6127cf..1401a0bcef3a 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -285,6 +285,7 @@ def create_convert_map(
             # tensor creation
             "_to_copy.default": self._to_copy,
             "arange.start": self._arange,
+            "clone.default": lambda node: self.env[node.args[0]],
             "empty.memory_format": self._empty,
             "fill.Scalar": self._fill,
             "new_ones.default": self._new_ones,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index a9559072043f..65890ff6971b 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3353,6 +3353,26 @@ def main(
     verify_model(Arange(), example_args, {}, Expected)
 
 
+def test_clone():
+    class Clone(Module):
+        def forward(self, input):
+            return torch.clone(input)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(10, 10, dtype=torch.float32),)
+    verify_model(Clone(), example_args, {}, Expected)
+
+
 def test_empty():
     class Empty(Module):
         def forward(self, input):