diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 97ccc6393cbb..50c29397d7ee 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -390,9 +390,11 @@ def create_convert_map( "reshape.default": self._reshape, # tensor creation "_to_copy.default": self._to_copy, + "arange.default": self._arange, + "arange.start": self._arange, + "arange.start_step": self._arange, "detach.default": self._detach, "detach_.default": self._detach, - "arange.start": self._arange, "contiguous.default": lambda node: self.env[node.args[0]], # no-op "clone.default": lambda node: self.env[node.args[0]], "empty.memory_format": self._empty, diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 19b8f80a2390..56ee527caf09 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -467,6 +467,39 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_arange(target, dev): + # arange.default + raw_data = np.array([0, 0, 0, 0, 0]) + + class ArangeDefaultModel(nn.Module): + def forward(self, x): + return x + torch.arange(5) + + torch_module = ArangeDefaultModel().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + # arange.start + raw_data = np.array([0, 0, 0]) + + class ArangeStartModel(nn.Module): + def forward(self, x): + return x + torch.arange(1, 4) + + torch_module = ArangeStartModel().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + # arange.start_step + raw_data = np.array([0.0, 0.0, 0.0], dtype=np.float32) + + class ArangeStartStopModel(nn.Module): + def forward(self, x): + return x + torch.arange(1, 2.5, 0.5, dtype=torch.float32) + + torch_module = ArangeStartStopModel().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + @tvm.testing.parametrize_targets("cuda") def test_index_select(target, dev): class IndexSelectModel(nn.Module):