From 58280f69a6b2898e7cab95196ee9d44532f5f81a Mon Sep 17 00:00:00 2001
From: Alexander Pivovarov
Date: Tue, 23 Feb 2021 22:05:12 -0800
Subject: [PATCH] [Torch] Add copy_ operator

---
 python/tvm/relay/frontend/pytorch.py          | 22 +++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 19 ++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 205b2aa779e6..ca0d8ac65aaf 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1298,6 +1298,27 @@ def clone(self, inputs, input_types):
         data = inputs[0]
         return _op.tensor.copy(data)
 
+    # Copies the elements of the src tensor into the dst tensor.
+    # src may be of a different data type; its shape must match the
+    # trailing dimensions of dst, along whose leading dims it is replicated.
+    def copy(self, inputs, input_types):
+        dst = inputs[0]
+        src = inputs[1]
+        # The 3rd (bool) parameter 'non_blocking' is ignored.
+        # Get input/output shapes and dtypes.
+        dst_shape = self.infer_shape(dst)
+        src_shape = self.infer_shape(src)
+        dst_dtype = input_types[0]
+        src_dtype = input_types[1]
+        # Cast src to the dst dtype if necessary.
+        out = _op.cast(src, dst_dtype) if src_dtype != dst_dtype else src
+        # Number of extra leading dims that dst has over src.
+        front_shape_len = len(dst_shape) - len(src_shape)
+        for dim_sz in reversed(dst_shape[:front_shape_len]):
+            # Prepend a new axis and replicate src dim_sz times along it.
+            out = _op.stack([out] * dim_sz, axis=0)
+        return out
+
     def log_softmax(self, inputs, input_types):
         data = inputs[0]
         axis = int(inputs[1])
@@ -2246,6 +2267,7 @@ def create_convert_map(self):
             "aten::view": self.view,
             "aten::reshape": self.reshape,
             "aten::clone": self.clone,
+            "aten::copy_": self.copy,
             "aten::log_softmax": self.log_softmax,
             "aten::sigmoid": self.sigmoid,
             "aten::softplus": self.softplus,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index aa42b0fb84e4..b6ec4cecae8f 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1267,6 +1267,23 @@ def forward(self, *args):
     verify_model(Clone1().float().eval(), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_copy():
+    def test_fn_copy():
+        return lambda dst, src: dst.copy_(src)
+
+    targets = ["llvm", "cuda"]
+    # Copy an int32 src tensor of shape (4,) into a float32 dst tensor of shape (3, 2, 4)
+    dst = torch.zeros((3, 2, 4), dtype=torch.float32)
+    src = torch.tensor([0, 1, 2, 6], dtype=torch.int32)
+    verify_trace_model(test_fn_copy(), [dst, src], targets)
+
+    # Copy a float32 src tensor of shape (4,) into a float32 dst tensor of the same shape
+    dst = torch.zeros((4,), dtype=torch.float32)
+    src = torch.tensor([0, 1, 2, 6], dtype=torch.float32)
+    verify_trace_model(test_fn_copy(), [dst, src], targets)
+
+
 @tvm.testing.uses_gpu
 def test_forward_gather():
     torch.set_grad_enabled(False)
@@ -3718,6 +3735,7 @@ def test_fn(x, mask):
     test_forward_true_divide()
     test_forward_is_floating_point()
     test_forward_clone()
+    test_forward_copy()
     test_forward_softplus()
     test_forward_softsign()
     test_forward_logsoftmax()
@@ -3780,6 +3798,7 @@ def test_fn(x, mask):
     test_forward_unbind()
     test_forward_nonzero()
     test_forward_scatter()
+    test_forward_index_put()
     test_numel()
     test_bincount()
     test_cumsum()
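
Note: the stack-based loop in the converter is easiest to see with a NumPy analogue. This is an illustrative sketch only, not part of the patch; `copy_broadcast` is a hypothetical helper that mirrors the converter's reversed-prefix iteration, using `np.stack` in place of `_op.stack`.

```python
import numpy as np

def copy_broadcast(dst_shape, src):
    # Mirror of the converter's loop: walk dst's extra leading dims
    # from innermost to outermost, prepending an axis and replicating
    # src along it at each step.
    out = src
    front_shape_len = len(dst_shape) - len(src.shape)
    for dim_sz in reversed(dst_shape[:front_shape_len]):
        # Stacking dim_sz copies along axis 0 prepends an axis of size
        # dim_sz, just like _op.stack([out] * dim_sz, axis=0) in Relay.
        out = np.stack([out] * dim_sz, axis=0)
    return out

src = np.array([0, 1, 2, 6], dtype="int32").astype("float32")  # dtype cast, as the converter does
out = copy_broadcast((3, 2, 4), src)
assert out.shape == (3, 2, 4)
assert (out == np.array([0, 1, 2, 6], dtype="float32")).all()
```

Iterating the prefix in reverse matters: the innermost extra dim is prepended first, so the outermost extra dim ends up as axis 0, which matches how `dst.copy_(src)` broadcasts src into every trailing slice of dst. `verify_trace_model` in the test then checks this end to end against PyTorch on each target.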