From 97ada56abedc9dfd8262594681c7d68c68824550 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 30 Mar 2025 11:58:50 -0400 Subject: [PATCH 1/3] arange.default ok --- .../torch/exported_program_translator.py | 6 +++-- .../relax/test_from_exported_to_cuda.py | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 0f1dc11787da..684cf344cdd4 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -389,15 +389,16 @@ def create_convert_map( "reshape.default": self._reshape, # tensor creation "_to_copy.default": self._to_copy, - "lift_fresh_copy.default": self._to_copy, + "arange.default": self._arange, + "arange.start": self._arange, "detach.default": self._detach, "detach_.default": self._detach, - "arange.start": self._arange, "contiguous.default": lambda node: self.env[node.args[0]], # no-op "clone.default": lambda node: self.env[node.args[0]], "empty.memory_format": self._empty, "empty_like.default": self._empty_like, "fill.Scalar": self._fill, + "lift_fresh_copy.default": self._to_copy, "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, # other @@ -490,6 +491,7 @@ def from_exported_program( assert ( func_name in self.convert_map ), f"Unsupported function type {func_name}" + print('found function!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!', func_name) self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 64babdc43a5c..e77ed464050a 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -466,6 +466,32 @@ def forward(self, x): torch_module = ChunkModel(chunks=chunks, dim=dim).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_arange_default(target, dev): + raw_data = np.random.rand(5).astype("int64") + + class ArangeModel(nn.Module): + def forward(self, x): + return x + torch.arange(5) + + torch_module = ArangeModel().eval() + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +# TODO +# @tvm.testing.parametrize_targets("cuda") +# def test_arange_start_step(target, dev): +# raw_data = np.random.rand(3).astype("int64") + +# class ArangeModel(nn.Module): +# def forward(self, x): +# return x + torch.arange(1, 2.5, 0.5) + +# torch_module = ArangeModel().eval() + +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + if __name__ == "__main__": tvm.testing.main() From 3431e183fee7f1a3b1f686529edda0e24b1cfc2c Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 30 Mar 2025 12:03:29 -0400 Subject: [PATCH 2/3] all arange tests pass --- .../torch/exported_program_translator.py | 3 +- .../relax/test_from_exported_to_cuda.py | 31 +++++++++++++------ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 684cf344cdd4..7cc071d01e33 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -390,7 +390,8 @@ def create_convert_map( # tensor creation "_to_copy.default": self._to_copy, "arange.default": self._arange, - "arange.start": self._arange, + "arange.start": self._arange, # TODO test + "arange.start_step": self._arange, "detach.default": self._detach, "detach_.default": self._detach, "contiguous.default": lambda node: self.env[node.args[0]], # no-op diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index e77ed464050a..9a0cd2c67cda 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -468,7 +468,7 @@ def forward(self, x): @tvm.testing.parametrize_targets("cuda") def test_arange_default(target, dev): - raw_data = np.random.rand(5).astype("int64") + raw_data = np.array([0,0,0,0,0]) class ArangeModel(nn.Module): def forward(self, x): @@ -478,19 +478,30 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_arange_start(target, dev): + raw_data = np.array([0,0,0]) + + class ArangeModel(nn.Module): + def forward(self, x): + return x + torch.arange(1, 4) + + torch_module = ArangeModel().eval() + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -# TODO -# @tvm.testing.parametrize_targets("cuda") -# def test_arange_start_step(target, dev): -# raw_data = np.random.rand(3).astype("int64") -# class ArangeModel(nn.Module): -# def forward(self, x): -# return x + torch.arange(1, 2.5, 0.5) +@tvm.testing.parametrize_targets("cuda") +def test_arange_start_step(target, dev): + raw_data = np.array([0.0,0.0,0.0], dtype=np.float32) -# torch_module = ArangeModel().eval() + class ArangeModel(nn.Module): + def forward(self, x): + return x + torch.arange(1, 2.5, 0.5, dtype=torch.float32) -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + torch_module = ArangeModel().eval() + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) if __name__ == "__main__": From 5c0aa4f824bb63dfe5c104525b9a6ebcdaf5180d Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 30 Mar 2025 12:06:25 -0400 Subject: [PATCH 3/3] arange test complete --- .../torch/exported_program_translator.py | 3 +- .../relax/test_from_exported_to_cuda.py | 32 ++++++++----------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 7cc071d01e33..c0052280433b 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -390,7 +390,7 @@ def create_convert_map( # tensor creation "_to_copy.default": self._to_copy, "arange.default": self._arange, - "arange.start": self._arange, # TODO test + "arange.start": self._arange, "arange.start_step": self._arange, "detach.default": self._detach, "detach_.default": self._detach, @@ -492,7 +492,6 @@ def from_exported_program( assert ( func_name in self.convert_map ), f"Unsupported function type {func_name}" - print('found function!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!', func_name) self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 9a0cd2c67cda..5a49c6f5f434 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -466,41 +466,37 @@ def forward(self, x): torch_module = ChunkModel(chunks=chunks, dim=dim).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + @tvm.testing.parametrize_targets("cuda") -def test_arange_default(target, dev): - raw_data = np.array([0,0,0,0,0]) +def test_arange(target, dev): + # arange.default + raw_data = np.array([0, 0, 0, 0, 0]) - class ArangeModel(nn.Module): + class ArangeDefaultModel(nn.Module): def forward(self, x): return x + torch.arange(5) - torch_module = ArangeModel().eval() - + torch_module = ArangeDefaultModel().eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -@tvm.testing.parametrize_targets("cuda") -def test_arange_start(target, dev): - raw_data = np.array([0,0,0]) + # arange.start + raw_data = np.array([0, 0, 0]) - class ArangeModel(nn.Module): + class ArangeStartModel(nn.Module): def forward(self, x): return x + torch.arange(1, 4) - torch_module = ArangeModel().eval() - + torch_module = ArangeStartModel().eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + # arange.start_step + raw_data = np.array([0.0, 0.0, 0.0], dtype=np.float32) -@tvm.testing.parametrize_targets("cuda") -def test_arange_start_step(target, dev): - raw_data = np.array([0.0,0.0,0.0], dtype=np.float32) - - class ArangeModel(nn.Module): + class ArangeStartStopModel(nn.Module): def forward(self, x): return x + torch.arange(1, 2.5, 0.5, dtype=torch.float32) - torch_module = ArangeModel().eval() - + torch_module = ArangeStartStopModel().eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)