diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index ff5e51da0b25..7c9ab43cf98e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1549,6 +1549,25 @@ def _new_ones(self, node: fx.Node) -> relax.Var: ) ) + def _new_zeros(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + input_tensor = args[0] + size = ( + args[1] + if isinstance(args[1], (list, tuple)) + else (args[1],) + if len(args[1:]) == 1 + else args[1:] + ) + size = relax.ShapeExpr(size) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(0, input_tensor.struct_info.dtype), + input_tensor.struct_info.dtype, + ) + ) + def _ones(self, node: fx.Node) -> relax.Var: import torch diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 86f5de5f367c..0bafe4e879cb 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -482,6 +482,7 @@ def create_convert_map( "linspace.default": self._linspace, "masked_fill.Scalar": self._masked_fill, "new_ones.default": self._new_ones, + "new_zeros.default": self._new_zeros, "one_hot.default": self._one_hot, "ones.default": self._ones, "ones_like.default": lambda node: self.block_builder.emit( diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 07b3df20aa72..39e562e06a5b 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -829,6 +829,7 @@ def create_convert_map( "masked_fill": self._masked_fill, "masked_scatter": self._masked_scatter, "new_ones": self._new_ones, + "new_zeros": self._new_zeros, "ones": self._ones, "one_hot": self._one_hot, "ones_like": lambda node: self.block_builder.emit( diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index dd1869a23c63..8e9dacf57436 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3751,6 +3751,29 @@ def main( verify_model(NewOnes(), example_args, {}, expected1) +def test_new_zeros(): + class NewZeros(torch.nn.Module): + def forward(self, x): + return x.new_zeros(1, 128, 128) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 128, 128), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 128, 128), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 128, 128), dtype="float32") = R.full( + R.shape([1, 128, 128]), R.const(0, "float32"), dtype="float32" + ) + gv: R.Tuple(R.Tensor((1, 128, 128), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 128, 128, dtype=torch.float32),) + verify_model(NewZeros(), example_args, {}, expected1) + + def test_to_copy(): # float class ToFloat(Module): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index f60f158cbfa4..354a3080cc20 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3327,6 +3327,31 @@ def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype=" verify_model(NewOnes(), input_info, {}, expected1) +def test_new_zeros(): + input_info = [([1, 128, 128], "float32")] + + class NewZeros(Module): + def forward(self, x): + return x.new_zeros(1, 128, 128) + + @tvm.script.ir_module + class expected: + @R.function + def main( + x: R.Tensor((1, 128, 128), dtype="float32") + ) -> R.Tensor((1, 128, 128), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 128, 128), dtype="float32") = R.full( + (1, 128, 128), R.const(0.0, "float32"), dtype="float32" + ) + gv: R.Tensor((1, 128, 128), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(NewZeros(), input_info, {}, expected) + + def test_expand(): input_info = [([1, 2, 3, 4], "float32")]