From d49d7d9e3da71bfb1fec34ec6538e1ee5f6280e6 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 20 Apr 2025 00:18:55 +0800 Subject: [PATCH 1/7] Update base_fx_graph_translator.py --- .../torch/base_fx_graph_translator.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index ae4c918900ec..52d3896ff2a4 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1461,6 +1461,26 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: self.env[node.args[0]] = output return output + def _linspace(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + start = args[0] + stop = args[1] + step = args[2] + + if step != 1: + step = (stop - start) / (step - 1) + stop = stop + (step / 2) + else: + stop = start + step + + if len(args) <= 3 or args[3] is None: + import torch + dtype = self._convert_data_type(str(torch.get_default_dtype())) + else: + dtype = self._convert_data_type(args[3]) + + return self.block_builder.emit(relax.op.arange(start=start, end=stop, step=step, dtype=dtype)) + def _masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] From d5b68e5df5d1a5039bd380025a24afd1eef797e5 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 20 Apr 2025 00:20:46 +0800 Subject: [PATCH 2/7] Update exported_program_translator.py --- python/tvm/relax/frontend/torch/exported_program_translator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 932607287571..b7732eaac096 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -458,6 +458,7 @@ def create_convert_map( "full_like.default": self._full_like, "index_select.default": self._index_select, "lift_fresh_copy.default": self._to_copy, + "linspace.default": self._linspace, "masked_fill.Scalar": self._masked_fill, "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, From 22277eced1b0cfee191e9b38adddd1d4df28975d Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 20 Apr 2025 00:22:22 +0800 Subject: [PATCH 3/7] Update test_frontend_from_exported_program.py --- .../test_frontend_from_exported_program.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 80c0bd5fb4f5..d08bc2c3f00f 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4377,5 +4377,26 @@ def main( verify_model(Narrow(), example_args, {}, Expected) +def test_linspace(): + class Linspace(Module): + def forward(self, input): + return torch.linspace(0, 1, steps=9, dtype=torch.float32) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((9, 9), dtype="float32") + ) -> R.Tuple(R.Tensor((9,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32") + gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(9, 9, dtype=torch.float32),) + verify_model(Linspace(), example_args, {}, Expected) + + if __name__ == "__main__": tvm.testing.main() From b5b2fa12ba00583a875f66a5cc4b86a974667c3d Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 24 Apr 2025 21:34:53 +0800 Subject: [PATCH 4/7] Update test_frontend_from_exported_program.py --- tests/python/relax/test_frontend_from_exported_program.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index a30759aca6ed..1373c0cccad5 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4471,7 +4471,7 @@ def main( verify_model(Narrow(), example_args, {}, Expected) - def test_item(): +def test_item(): class Item(Module): def forward(self, x): return x.item() From da55933d12450304fe3e5ff8c351a49da2ed8e97 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 24 Apr 2025 21:47:15 +0800 Subject: [PATCH 5/7] Update test_frontend_from_exported_program.py --- tests/python/relax/test_frontend_from_exported_program.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 1373c0cccad5..d9f7853e1893 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4646,7 +4646,7 @@ def test_linspace(): class Linspace(Module): def forward(self, input): return torch.linspace(0, 1, steps=9, dtype=torch.float32) - + def main( input: R.Tensor((9, 9), dtype="float32") ) -> R.Tuple(R.Tensor((9,), dtype="float32")): From b002101a453a5df9bcad0ed1cf0012136bf1ab25 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 24 Apr 2025 21:49:14 +0800 Subject: [PATCH 6/7] fix lint --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index a791a7694d9d..a67e7941c48c 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1462,11 +1462,14 @@ def _linspace(self, node: fx.Node) -> relax.Var: if len(args) <= 3 or args[3] is None: import torch + dtype = self._convert_data_type(str(torch.get_default_dtype())) else: dtype = self._convert_data_type(args[3]) - return self.block_builder.emit(relax.op.arange(start=start, end=stop, step=step, dtype=dtype)) + return self.block_builder.emit( + relax.op.arange(start=start, end=stop, step=step, dtype=dtype) + ) def _masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] From 34492b6a9eae4ab0be392704878a91f4ffb673a8 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Fri, 25 Apr 2025 00:34:48 +0800 Subject: [PATCH 7/7] Update test_frontend_from_exported_program.py --- tests/python/relax/test_frontend_from_exported_program.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index d9f7853e1893..420db0b72fef 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4647,6 +4647,9 @@ class Linspace(Module): def forward(self, input): return torch.linspace(0, 1, steps=9, dtype=torch.float32) + @tvm.script.ir_module + class Expected: + @R.function def main( input: R.Tensor((9, 9), dtype="float32") ) -> R.Tuple(R.Tensor((9,), dtype="float32")):