From 6323caef74a5d434ba958d5803cfc3a9d9e52ca8 Mon Sep 17 00:00:00 2001 From: Quanfeng Li Date: Sun, 23 Apr 2023 10:43:01 +0800 Subject: [PATCH] [PyTorch] Add aten::new_zeros --- python/tvm/relay/frontend/pytorch.py | 17 +++++++++++++++++ tests/python/frontend/pytorch/test_forward.py | 7 +++++++ 2 files changed, 24 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 7e6205355aca..24feccec43ef 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -832,6 +832,22 @@ def zeros_like(self, inputs, input_types): return out + def new_zeros(self, inputs, input_types): + data = inputs[1] + + import torch + + if not isinstance(data, (_expr.Expr, list, tuple, torch.Size)): + msg = "Data type %s could not be parsed in new_zeros op" % (type(data)) + raise AssertionError(msg) + + if inputs[2] is not None: + dtype = _convert_dtype_value(inputs[2]) + else: + # if dtype is None, use the dtype of the input tensor + dtype = self.infer_type(inputs[0]) + return self.full_impl(data, 0, dtype) + def full(self, inputs, input_types): data = inputs[0] fill_value = inputs[1] @@ -3755,6 +3771,7 @@ def create_convert_map(self): "aten::zeros": self.zeros, "aten::zero_": self.zero_, "aten::zeros_like": self.zeros_like, + "aten::new_zeros": self.new_zeros, "aten::new_ones": self.new_ones, "aten::full": self.full, "aten::full_like": self.full_like, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index b5fcaaecaec5..897ebdec447f 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3348,6 +3348,13 @@ def forward(self, *args): verify_model(ZerosLike3().float().eval(), input_data=input_data) +def test_forward_new_zeros(): + def test_func(x): + return x.new_zeros((2, 3)) + + verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()]) + + @tvm.testing.uses_gpu def test_forward_full(): """test_forward_full"""