22 changes: 22 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -1298,6 +1298,27 @@ def clone(self, inputs, input_types):
        data = inputs[0]
        return _op.tensor.copy(data)

    # Copies the elements of the src tensor into dst.
    # src must be broadcastable with dst and may have a different dtype.
    def copy(self, inputs, input_types):
        dst = inputs[0]
        src = inputs[1]
        # The third input, the bool 'non_blocking', is ignored.
        # Get input/output shapes and dtypes.
        dst_shape = self.infer_shape(dst)
        src_shape = self.infer_shape(src)
        dst_dtype = input_types[0]
        src_dtype = input_types[1]
        # Cast src to the dst dtype if necessary.
        out = _op.cast(src, dst_dtype) if src_dtype != dst_dtype else src
        # For each leading dim of dst that src lacks (innermost first),
        # stack copies of src so the result takes on dst's shape.
        front_shape_len = len(dst_shape) - len(src_shape)
        for dim_sz in reversed(dst_shape[:front_shape_len]):
            out = _op.stack([out] * dim_sz, axis=0)
        return out
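For intuition, here is a minimal NumPy sketch of the same expansion loop (illustrative only: copy_like is a made-up name, and NumPy stands in for the Relay ops used above):

import numpy as np

def copy_like(dst, src):
    # Mirror of the converter: cast first, then stack copies of src
    # along each leading dst dimension it lacks, innermost first.
    out = src.astype(dst.dtype) if src.dtype != dst.dtype else src
    front = dst.ndim - out.ndim
    for dim_sz in reversed(dst.shape[:front]):
        out = np.stack([out] * dim_sz, axis=0)
    return out

dst = np.zeros((3, 2, 4), dtype="float32")
src = np.array([0, 1, 2, 6], dtype="int32")
out = copy_like(dst, src)
assert out.shape == (3, 2, 4) and out.dtype == np.float32
assert (out[1, 0] == np.array([0, 1, 2, 6], dtype="float32")).all()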

    def log_softmax(self, inputs, input_types):
        data = inputs[0]
        axis = int(inputs[1])
@@ -2246,6 +2267,7 @@ def create_convert_map(self):
            "aten::view": self.view,
            "aten::reshape": self.reshape,
            "aten::clone": self.clone,
            "aten::copy_": self.copy,
            "aten::log_softmax": self.log_softmax,
            "aten::sigmoid": self.sigmoid,
            "aten::softplus": self.softplus,
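For reference, the PyTorch semantics this converter models can be checked directly, independent of TVM (a quick sanity sketch):

import torch

dst = torch.zeros(3, 2, 4)  # float32 by default
src = torch.tensor([0, 1, 2, 6], dtype=torch.int32)
dst.copy_(src)  # src is cast to float32 and broadcast to (3, 2, 4)
assert dst[1, 0].tolist() == [0.0, 1.0, 2.0, 6.0]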
19 changes: 19 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -1267,6 +1267,23 @@ def forward(self, *args):
    verify_model(Clone1().float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_copy():
    def test_fn_copy():
        return lambda dst, src: dst.copy_(src)

    targets = ["llvm", "cuda"]
    # Copy an int32 src tensor of shape (4,) into a float32 dst tensor of shape (3, 2, 4)
    dst = torch.zeros((3, 2, 4), dtype=torch.float32)
    src = torch.tensor([0, 1, 2, 6], dtype=torch.int32)
    verify_trace_model(test_fn_copy(), [dst, src], targets)

    # Copy a float32 src tensor of shape (4,) into a float32 dst tensor of the same shape
    dst = torch.zeros((4,), dtype=torch.float32)
    src = torch.tensor([0, 1, 2, 6], dtype=torch.float32)
    verify_trace_model(test_fn_copy(), [dst, src], targets)
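To exercise just this test locally, something along these lines should work (assuming a TVM checkout with pytest installed; the cuda cases should be skipped when no GPU is available):

pytest tests/python/frontend/pytorch/test_forward.py -k test_forward_copy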


@tvm.testing.uses_gpu
def test_forward_gather():
    torch.set_grad_enabled(False)
@@ -3718,6 +3735,7 @@ def test_fn(x, mask):
    test_forward_true_divide()
    test_forward_is_floating_point()
    test_forward_clone()
    test_forward_copy()
    test_forward_softplus()
    test_forward_softsign()
    test_forward_logsoftmax()
@@ -3780,6 +3798,7 @@ def test_fn(x, mask):
    test_forward_unbind()
    test_forward_nonzero()
    test_forward_scatter()
    test_forward_index_put()
    test_numel()
    test_bincount()
    test_cumsum()