From e13fa3d0718eb75ede16a60b138e62b48360d586 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 13 Apr 2025 20:24:29 -0400
Subject: [PATCH 1/5] unit test

---
 tests/python/relax/test_from_exported_to_cuda.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 8405f48576d8..43107f015313 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -63,6 +63,21 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_full(target, dev):
+    class FullModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.full((2, 3), 3.141592)
+
+    torch_module = FullModel().eval()
+
+    raw_data = np.random.rand(3,3).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_clamp(target, dev):
     class ClampBothTensor(torch.nn.Module):
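Note: the test above exercises the torch.export ingestion path end to end, and is
added first so that the converter landing in patch 2 has coverage from the start.
For readers unfamiliar with the helper, the translation step it drives looks roughly
like the sketch below (a minimal outline assuming the public
tvm.relax.frontend.torch.from_exported_program entry point; the helper's exact
internals live in test_from_exported_to_cuda.py):

    import numpy as np
    import torch
    from tvm.relax.frontend.torch import from_exported_program

    class FullModel(torch.nn.Module):
        def forward(self, x):
            return torch.full((2, 3), 3.141592)

    # torch.export lowers the module to ATen ops; torch.full arrives as full.default.
    example_input = torch.from_numpy(np.random.rand(3, 3).astype("float32"))
    exported = torch.export.export(FullModel().eval(), (example_input,))

    # Translate the ExportedProgram into a Relax IRModule. Before patch 2 this step
    # raises, because no converter is registered for full.default.
    mod = from_exported_program(exported)
    mod.show()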
From 5b23c30341fff3765c1226261d84f178531485b1 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 13 Apr 2025 20:26:10 -0400
Subject: [PATCH 2/5] full.default

---
 .../frontend/torch/base_fx_graph_translator.py   | 17 +++++++++++++++++
 .../torch/exported_program_translator.py         |  1 +
 python/tvm/relax/frontend/torch/fx_translator.py | 17 -----------------
 3 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index c9c6afd71a64..55a603e20c60 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1271,6 +1271,23 @@ def _fill(self, node: fx.Node) -> relax.Var:
         value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
         return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
 
+    def _full(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                value,
+                dtype,
+            )
+        )
+
     def _index_select(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 875ec3b83ea8..26e73dd6b84b 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -433,6 +433,7 @@ def create_convert_map(
             "empty.memory_format": self._empty,
             "empty_like.default": self._empty_like,
             "fill.Scalar": self._fill,
+            "full.default": self._full,
             "index_select.default": self._index_select,
             "lift_fresh_copy.default": self._to_copy,
             "new_ones.default": self._new_ones,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a5b50a7d1dce..80031cd7a403 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -468,23 +468,6 @@ def _inplace_fill(self, node: fx.Node) -> relax.Var:
         self.env[node.args[0]] = filled
         return filled
 
-    def _full(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
-        dtype = self._convert_data_type(
-            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
-        )
-        value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                value,
-                dtype,
-            )
-        )
-
     def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]

From 35aee297ba2ca01dbdf2695267cf1869b399a95b Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 13 Apr 2025 20:27:32 -0400
Subject: [PATCH 3/5] linting

---
 tests/python/relax/test_from_exported_to_cuda.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 43107f015313..0a120aa8fb70 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -71,13 +71,14 @@ def __init__(self):
 
         def forward(self, x):
             return torch.full((2, 3), 3.141592)
-
+
     torch_module = FullModel().eval()
 
-    raw_data = np.random.rand(3,3).astype("float32")
+    raw_data = np.random.rand(3, 3).astype("float32")
 
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
+
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_clamp(target, dev):
     class ClampBothTensor(torch.nn.Module):
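Note on the structure of patch 2: the exported-program translator and the older FX
translator share the base class in base_fx_graph_translator.py, so _full is moved
into that shared base rather than duplicated, and the only genuinely new line is the
"full.default" convert-map entry. What the handler emits is an ordinary relax.op.full
call. A standalone sketch of the equivalent construction through the public
BlockBuilder API (illustrative only; this is not the translator code itself):

    from tvm import relax

    bb = relax.BlockBuilder()
    with bb.function("main", params=[]):
        out = bb.emit(
            relax.op.full(
                relax.ShapeExpr([2, 3]),           # size taken from the node args
                relax.const(3.141592, "float32"),  # scalar fill value as a constant
                "float32",                         # torch's default dtype unless overridden
            )
        )
        bb.emit_func_output(out)
    print(bb.get())  # IRModule whose main() returns the constant-filled (2, 3) tensor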
"one_hot.default": self._one_hot, + "ones.default": self._ones, # datatype "to.dtype": self._to, "to.dtype_layout": self._to, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 80031cd7a403..f1b9a6d6e28c 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -510,22 +510,6 @@ def _masked_scatter(self, node: fx.Node) -> relax.Var: mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape)) return self.block_builder.emit(relax.op.where(mask, gathered_source, x)) - def _ones(self, node: fx.Node) -> relax.Var: - import torch - - args = self.retrieve_args(node) - size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) - dtype = self._convert_data_type( - node.kwargs.get("dtype", torch.get_default_dtype()), self.env - ) - return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, dtype), - dtype, - ) - ) - def _one_hot(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes") diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 0a120aa8fb70..5a0435d44484 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -78,6 +78,22 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_ones(target, dev): + class FullModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.ones((2, 3)) + + torch_module = FullModel().eval() + + raw_data = np.random.rand(1,1).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): From 40316a073d3554495874478f14a438e02005c045 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:54:57 -0400 Subject: [PATCH 5/5] tests for ones, full, and full like work --- .../torch/base_fx_graph_translator.py | 7 ++++++- .../torch/exported_program_translator.py | 1 + .../relax/test_from_exported_to_cuda.py | 20 +++++++++++++++++-- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 2a811fd33e1e..3018b0db771d 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1288,6 +1288,11 @@ def _full(self, node: fx.Node) -> relax.Var: ) ) + def _full_like(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + fill_value = relax.const(node.args[1]) + return self.block_builder.emit(relax.op.full_like(x, fill_value)) + def _index_select(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] @@ -1308,7 +1313,7 @@ def _new_ones(self, node: fx.Node) -> relax.Var: self_var.struct_info.dtype, ) ) - + def _ones(self, node: fx.Node) -> relax.Var: import torch diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index e962fbdbc696..bcb8b6468f72 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ 
From 40316a073d3554495874478f14a438e02005c045 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 13 Apr 2025 20:54:57 -0400
Subject: [PATCH 5/5] tests for ones, full, and full like work

---
 .../torch/base_fx_graph_translator.py            |  7 ++++++-
 .../torch/exported_program_translator.py         |  1 +
 .../relax/test_from_exported_to_cuda.py          | 20 +++++++++++++++++--
 3 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 2a811fd33e1e..3018b0db771d 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1288,6 +1288,11 @@ def _full(self, node: fx.Node) -> relax.Var:
             )
         )
 
+    def _full_like(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        fill_value = relax.const(node.args[1])
+        return self.block_builder.emit(relax.op.full_like(x, fill_value))
+
     def _index_select(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1]
@@ -1308,7 +1313,7 @@ def _new_ones(self, node: fx.Node) -> relax.Var:
                 self_var.struct_info.dtype,
             )
         )
-
+
     def _ones(self, node: fx.Node) -> relax.Var:
         import torch
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index e962fbdbc696..bcb8b6468f72 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -434,6 +434,7 @@ def create_convert_map(
             "empty_like.default": self._empty_like,
             "fill.Scalar": self._fill,
             "full.default": self._full,
+            "full_like.default": self._full_like,
             "index_select.default": self._index_select,
             "lift_fresh_copy.default": self._to_copy,
             "new_ones.default": self._new_ones,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 5a0435d44484..e92855885e35 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -78,6 +78,23 @@ def forward(self, x):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
+
+@tvm.testing.parametrize_targets("cuda")
+def test_full_like(target, dev):
+    class FullLike(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fill_value = 7.0
+
+        def forward(self, x):
+            return torch.full_like(x, self.fill_value)
+
+    torch_module = FullLike().eval()
+    raw_data = np.random.rand(2, 3).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_ones(target, dev):
     class FullModel(nn.Module):
@@ -89,12 +106,11 @@ def forward(self, x):
 
     torch_module = FullModel().eval()
 
-    raw_data = np.random.rand(1,1).astype("float32")
+    raw_data = np.random.rand(1, 1).astype("float32")
 
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
-
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_clamp(target, dev):
     class ClampBothTensor(torch.nn.Module):
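Note on _full_like: unlike _full, the handler needs no explicit shape or dtype, since
relax.op.full_like inherits both from its tensor argument; it only wraps the scalar
from node.args[1] in a Relax constant. A standalone sketch of the emitted pattern via
the public BlockBuilder API (illustrative only; relax.const(7.0) is assumed to default
to float32 here):

    from tvm import relax

    bb = relax.BlockBuilder()
    x = relax.Var("x", relax.TensorStructInfo((2, 3), "float32"))
    with bb.function("main", params=[x]):
        # full_like keeps x's shape and dtype; every element becomes the fill value.
        out = bb.emit(relax.op.full_like(x, relax.const(7.0)))
        bb.emit_func_output(out)
    print(bb.get())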