diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 33f6ffc3132e..a67e7941c48c 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1448,6 +1448,29 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: self.env[node.args[0]] = output return output + def _linspace(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + start = args[0] + stop = args[1] + step = args[2] + + if step != 1: + step = (stop - start) / (step - 1) + stop = stop + (step / 2) + else: + stop = start + step + + if len(args) <= 3 or args[3] is None: + import torch + + dtype = self._convert_data_type(str(torch.get_default_dtype())) + else: + dtype = self._convert_data_type(args[3]) + + return self.block_builder.emit( + relax.op.arange(start=start, end=stop, step=step, dtype=dtype) + ) + def _masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index aa4984c0ba90..a7464a612ba1 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -474,6 +474,7 @@ def create_convert_map( "full_like.default": self._full_like, "index_select.default": self._index_select, "lift_fresh_copy.default": self._to_copy, + "linspace.default": self._linspace, "masked_fill.Scalar": self._masked_fill, "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 5ef2c27e9133..420db0b72fef 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4642,5 +4642,26 @@ def main( verify_model(Eye2(), example_args2, {}, Expected2) +def test_linspace(): + class Linspace(Module): + def forward(self, input): + return torch.linspace(0, 1, steps=9, dtype=torch.float32) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((9, 9), dtype="float32") + ) -> R.Tuple(R.Tensor((9,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32") + gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(9, 9, dtype=torch.float32),) + verify_model(Linspace(), example_args, {}, Expected) + + if __name__ == "__main__": tvm.testing.main()