From 1cf2b776166f867223b68934fa0d4f2fd163cc5e Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Tue, 29 Apr 2025 15:36:52 +0000 Subject: [PATCH 1/4] add op support for new_zeros op --- .../torch/base_fx_graph_translator.py | 15 ++++++++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 1 + .../test_frontend_from_exported_program.py | 24 +++++++++++++++++++ tests/python/relax/test_frontend_from_fx.py | 23 ++++++++++++++++++ 5 files changed, 64 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index ff5e51da0b25..250246aed6e6 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1548,6 +1548,21 @@ def _new_ones(self, node: fx.Node) -> relax.Var: self_var.struct_info.dtype, ) ) + + def _new_zeros(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + self_var = args[0] + size = args[1] if isinstance(args[1], (list, tuple)) else args[1:] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(0, self_var.struct_info.dtype), + self_var.struct_info.dtype, + ) + ) def _ones(self, node: fx.Node) -> relax.Var: import torch diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 86f5de5f367c..0bafe4e879cb 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -482,6 +482,7 @@ def create_convert_map( "linspace.default": self._linspace, "masked_fill.Scalar": self._masked_fill, "new_ones.default": self._new_ones, + "new_zeros.default": self._new_zeros, "one_hot.default": self._one_hot, "ones.default": self._ones, "ones_like.default": lambda node: self.block_builder.emit( diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 07b3df20aa72..39e562e06a5b 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -829,6 +829,7 @@ def create_convert_map( "masked_fill": self._masked_fill, "masked_scatter": self._masked_scatter, "new_ones": self._new_ones, + "new_zeros": self._new_zeros, "ones": self._ones, "one_hot": self._one_hot, "ones_like": lambda node: self.block_builder.emit( diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index dd1869a23c63..15b8c6e2e299 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3751,6 +3751,30 @@ def main( verify_model(NewOnes(), example_args, {}, expected1) +def test_new_zeros(): + class NewZeros(Module): + def forward(self, x): + return x.new_zeros(1, 2, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3), dtype="float32") = R.full( + (1, 2, 3), R.const(0, "float32"), dtype="float32" # Changed to 0 + ) + gv: R.Tuple(R.Tensor((1, 2, 3), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, dtype=torch.float32),) + verify_model(NewZeros(), example_args, {}, expected1) + + def test_to_copy(): # float class ToFloat(Module): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index f60f158cbfa4..945e06f3d063 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3327,6 +3327,29 @@ def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype=" verify_model(NewOnes(), input_info, {}, expected1) +def test_new_zeros(): + input_info = [([1, 2, 3], "float32")] + + class NewZeros(Module): + def forward(self, x): + return x.new_zeros(1, 2, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3), dtype="float32") = R.full( + (1, 2, 3), R.const(0, "float32"), dtype="float32" # Changed to 0 + ) + gv: R.Tensor((1, 2, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(NewZeros(), input_info, {}, expected1) + + def test_expand(): input_info = [([1, 2, 3, 4], "float32")] From 90a19f6a4b8a09249c4c9286eaad0296809b8e96 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Wed, 30 Apr 2025 04:22:26 +0000 Subject: [PATCH 2/4] lint check and unity issue --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 250246aed6e6..4b0133195d98 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1534,6 +1534,7 @@ def _masked_fill(self, node: fx.Node) -> relax.Var: values = self.block_builder.emit(relax.op.full_like(x, rx_value)) return self.block_builder.emit(relax.op.where(mask, values, x)) + #new-zeros op def _new_ones(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) self_var = args[0] From 30f5cb7aac9d8c33cd1e9a998be0f1ab2cdfaad1 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Wed, 30 Apr 2025 08:44:37 +0000 Subject: [PATCH 3/4] changed test cases --- .../torch/base_fx_graph_translator.py | 11 ++++------- .../test_frontend_from_exported_program.py | 19 ++++++++++--------- tests/python/relax/test_frontend_from_fx.py | 16 ++++++++-------- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 4b0133195d98..170e62630bea 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1534,7 +1534,6 @@ def _masked_fill(self, node: fx.Node) -> relax.Var: values = self.block_builder.emit(relax.op.full_like(x, rx_value)) return self.block_builder.emit(relax.op.where(mask, values, x)) - #new-zeros op def _new_ones(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) self_var = args[0] @@ -1552,16 +1551,14 @@ def _new_ones(self, node: fx.Node) -> relax.Var: def _new_zeros(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) - self_var = args[0] - size = args[1] if isinstance(args[1], (list, tuple)) else args[1:] - if not isinstance(size, (list, tuple)): - size = (size,) + input_tensor = args[0] + size = args[1] if isinstance(args[1], (list, tuple)) else (args[1],) if len(args[1:]) == 1 else args[1:] size = relax.ShapeExpr(size) return self.block_builder.emit( relax.op.full( size, - relax.const(0, self_var.struct_info.dtype), - self_var.struct_info.dtype, + relax.const(0, input_tensor.struct_info.dtype), + input_tensor.struct_info.dtype, ) ) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 15b8c6e2e299..03a8f38f03d5 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3752,26 +3752,27 @@ def main( def test_new_zeros(): - class NewZeros(Module): + class NewZeros(torch.nn.Module): def forward(self, x): - return x.new_zeros(1, 2, 3) + return x.new_zeros(1, 128, 128) @tvm.script.ir_module class expected1: @R.function def main( - x: R.Tensor((1, 2, 3), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")): - # block 0 + x: R.Tensor((1, 128, 128), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 128, 128), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 2, 3), dtype="float32") = R.full( - (1, 2, 3), R.const(0, "float32"), dtype="float32" # Changed to 0 + lv: R.Tensor((1, 128, 128), dtype="float32") = R.full( + R.shape([1, 128, 128]), + R.const(0, "float32"), + dtype="float32" ) - gv: R.Tuple(R.Tensor((1, 2, 3), dtype="float32")) = (lv,) + gv: R.Tuple(R.Tensor((1, 128, 128), dtype="float32")) = (lv,) R.output(gv) return gv - example_args = (torch.randn(1, 2, 3, dtype=torch.float32),) + example_args = (torch.randn(1, 128, 128, dtype=torch.float32),) verify_model(NewZeros(), example_args, {}, expected1) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 945e06f3d063..56f60ed2d4e8 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3328,26 +3328,26 @@ def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype=" def test_new_zeros(): - input_info = [([1, 2, 3], "float32")] + input_info = [([1, 128, 128], "float32")] class NewZeros(Module): def forward(self, x): - return x.new_zeros(1, 2, 3) + return x.new_zeros(1, 128, 128) @tvm.script.ir_module - class expected1: + class expected: @R.function - def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype="float32"): + def main(x: R.Tensor((1, 128, 128), dtype="float32")) -> R.Tensor((1, 128, 128), dtype="float32"): # block 0 with R.dataflow(): - lv: R.Tensor((1, 2, 3), dtype="float32") = R.full( - (1, 2, 3), R.const(0, "float32"), dtype="float32" # Changed to 0 + lv: R.Tensor((1, 128, 128), dtype="float32") = R.full( + (1, 128, 128), R.const(0.0, "float32"), dtype="float32" ) - gv: R.Tensor((1, 2, 3), dtype="float32") = lv + gv: R.Tensor((1, 128, 128), dtype="float32") = lv R.output(gv) return gv - verify_model(NewZeros(), input_info, {}, expected1) + verify_model(NewZeros(), input_info, {}, expected) def test_expand(): From e90cb282225be1f9a92c9006ff132043aed0f384 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Wed, 30 Apr 2025 09:51:42 +0000 Subject: [PATCH 4/4] fixed lint issues --- .../relax/frontend/torch/base_fx_graph_translator.py | 10 ++++++++-- .../relax/test_frontend_from_exported_program.py | 4 +--- tests/python/relax/test_frontend_from_fx.py | 4 +++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 170e62630bea..7c9ab43cf98e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1548,11 +1548,17 @@ def _new_ones(self, node: fx.Node) -> relax.Var: self_var.struct_info.dtype, ) ) - + def _new_zeros(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) input_tensor = args[0] - size = args[1] if isinstance(args[1], (list, tuple)) else (args[1],) if len(args[1:]) == 1 else args[1:] + size = ( + args[1] + if isinstance(args[1], (list, tuple)) + else (args[1],) + if len(args[1:]) == 1 + else args[1:] + ) size = relax.ShapeExpr(size) return self.block_builder.emit( relax.op.full( diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 03a8f38f03d5..8e9dacf57436 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3764,9 +3764,7 @@ def main( ) -> R.Tuple(R.Tensor((1, 128, 128), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 128, 128), dtype="float32") = R.full( - R.shape([1, 128, 128]), - R.const(0, "float32"), - dtype="float32" + R.shape([1, 128, 128]), R.const(0, "float32"), dtype="float32" ) gv: R.Tuple(R.Tensor((1, 128, 128), dtype="float32")) = (lv,) R.output(gv) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 56f60ed2d4e8..354a3080cc20 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3337,7 +3337,9 @@ def forward(self, x): @tvm.script.ir_module class expected: @R.function - def main(x: R.Tensor((1, 128, 128), dtype="float32")) -> R.Tensor((1, 128, 128), dtype="float32"): + def main( + x: R.Tensor((1, 128, 128), dtype="float32") + ) -> R.Tensor((1, 128, 128), dtype="float32"): # block 0 with R.dataflow(): lv: R.Tensor((1, 128, 128), dtype="float32") = R.full(