diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 2688f83c86ca..5f65f86a4303 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -835,6 +835,7 @@ def create_convert_map( "fill_": self._inplace_fill, "full": self._full, "index_select": self._index_select, + "linspace": self._linspace, "masked_fill_": self._inplace_masked_fill, "masked_fill": self._masked_fill, "masked_scatter": self._masked_scatter, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 2ab20fbb1186..490a2309aa37 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -5396,5 +5396,23 @@ def forward(self, input): ) +def test_linspace(): + import numpy as np + + class Linspace(Module): + def forward(self, input): + return torch.linspace(0, 1, steps=9) + + graph_model = fx.symbolic_trace(Linspace()) + mod = from_fx(graph_model, [([9, 9], "float32")]) + assert len(mod["main"].body.blocks) == 1 + assert len(mod["main"].body.blocks[0].bindings) == 1 + assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant) + tvm.testing.assert_allclose( + mod["main"].body.blocks[0].bindings[0].value.data.numpy(), + np.linspace(0, 1, num=9, dtype="float32"), + ) + + if __name__ == "__main__": tvm.testing.main()