diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index f9b49204b85e..0038edf77a70 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3019,6 +3019,7 @@ def _op_dispatch(cls, operator, inputs, attr, params):
         op_map = {
             "size": cls._size,
             "arange": cls._arange,
+            "index_put": cls._index_put,
             "reshape": cls._reshape,
             "embedding_bag": cls._embedding_bag,
         }
@@ -3038,6 +3039,47 @@ def _size(cls, inputs, attr, params):
     def _arange(cls, inputs, attr, params):
         return _op.arange(inputs[0], inputs[1], inputs[2], dtype="int64")
 
+    @classmethod
+    def _check_index(cls, indices, values):
+        def unfolding_indices(indices, values):
+            n = len(indices)
+            flatten_indices = []
+            slices_size = []
+            for index in indices:
+                flatten_indices.append(_op.reshape(index, _op.const([-1])))
+                slices_size.append(infer_shape(flatten_indices[-1])[0])
+            repeat_size = [1]
+            tile_size = [1]
+            for i in range(1, n):
+                repeat_size.append(slices_size[-i] * repeat_size[-1])
+                tile_size.append(slices_size[i - 1] * tile_size[-1])
+            repeat_size.reverse()
+            unfold_slices = []
+            for i in range(n):
+                unfold_slices.append(
+                    fold_constant(
+                        _op.repeat(_op.tile(flatten_indices[i], (tile_size[i],)), repeat_size[i], 0)
+                    )
+                )
+            return unfold_slices, _op.reshape(values, _op.const([-1]))
+
+        values_shape = infer_shape(values)
+        if len(values_shape) != 1:
+            return unfolding_indices(indices, values)
+        return indices, values
+
+    @classmethod
+    def _index_put(cls, inputs, attr, params):
+        in_tensor = inputs[0]
+        indices, values = cls._check_index(inputs[1 : len(inputs) - 2], inputs[len(inputs) - 2])
+        accumulate = inputs[len(inputs) - 1].data.asnumpy() != 0
+        if not accumulate:
+            mode = "update"
+        else:
+            mode = "add"
+        index_tensor = _op.stack(indices, axis=0)
+        return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)
+
     @classmethod
     def _reshape(cls, inputs, attr, params):
         return _op.reshape(inputs[0], inputs[1])
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index a1d821686ed5..d6515be54985 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5045,6 +5045,81 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None
     verify_embedding_bag(32, 2, [3, 3])
 
 
+@tvm.testing.parametrize_targets
+def test_index_put(target, dev):
+    class _index_put_model(torch.nn.Module):
+        def __init__(self, indices, values, accumulate):
+            super(_index_put_model, self).__init__()
+            self.indices = indices
+            self.values = values
+            self.accumulate = accumulate
+
+        def forward(self, x):
+            return x.index_put(self.indices, self.values, self.accumulate)
+
+    def _convert_to_onnx(model, dummy_data):
+        file_name = "{}.onnx".format("aten_model")
+        torch.onnx.export(
+            model,
+            dummy_data,
+            file_name,
+            export_params=True,
+            verbose=False,
+            opset_version=11,
+            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
+        )
+        onnx_model = onnx.load(file_name)
+        return onnx_model
+
+    def verify_index_put(data_shape, indices, accumulate):
+        dummy_data = torch.ones(data_shape)
+        tvm_inputs = [dummy_data.numpy()]
+        values = torch.rand(indices[0].size())
+        model = _index_put_model(indices, values, accumulate)
+        onnx_model = _convert_to_onnx(model, dummy_data)
+        torch_out = model(dummy_data)
+
+        tvm_out = get_tvm_output_with_vm(
+            onnx_model, tvm_inputs, target, dev, freeze_params=True, convert_to_static=True
+        )
+        tvm.testing.assert_allclose(torch_out.numpy(), tvm_out)
+
+    shape = (3, 5)
+    xidx = torch.tensor([0, 1, 2, 2])
+    yidx = torch.tensor([0, 1, 3, 4])
+    verify_index_put(shape, [xidx, yidx], True)
+
+    shape = (3, 5, 3)
+    xidx = torch.tensor([0, 1, 2, 2, 0])
+    yidx = torch.tensor([0, 1, 3, 4, 0])
+    zidx = torch.tensor([0, 1, 1, 2, 0])
+    verify_index_put(shape, [xidx, yidx, zidx], False)
+
+    def verify_index_put_slice(data_shape, value_shape, accumulate):
+        dummy_data = torch.ones(data_shape)
+        tvm_inputs = [dummy_data.numpy()]
+        indices = []
+        index_shape = [1] * len(value_shape)
+        index_shape[0] = -1
+        for i in range(len(value_shape)):
+            indices.append(torch.arange(0, value_shape[i]).reshape(tuple(index_shape)))
+            index_shape.pop()
+        values = torch.rand(value_shape)
+
+        model = _index_put_model(indices, values, accumulate)
+        onnx_model = _convert_to_onnx(model, dummy_data)
+        torch_out = model(dummy_data)
+
+        tvm_out = get_tvm_output_with_vm(
+            onnx_model, tvm_inputs, target, dev, freeze_params=True, convert_to_static=True
+        )
+        tvm.testing.assert_allclose(torch_out.numpy(), tvm_out)
+
+    verify_index_put_slice((3, 3), (2, 2), False)
+    verify_index_put_slice((2, 3, 4), (1, 2, 3), True)
+    verify_index_put_slice((2, 3, 4, 5), (1, 2, 3, 1), False)
+
+
 @tvm.testing.parametrize_targets
 def test_reverse_sequence(target, dev):
     def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis):
@@ -5621,6 +5696,7 @@ def repeat(N, D):
     test_cumsum()
     test_wrong_input()
     test_aten()
+    test_index_put()
     test_reverse_sequence()
     test_eyelike()
     test_qlinearconv()
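
Note on the converter's index unfolding (not part of the patch): when `values` is multi-dimensional, `_check_index` flattens each per-dimension index and expands it with tile/repeat so that every index combination is enumerated row by row, which is the layout `scatter_nd` expects. The following NumPy-only sketch, with hypothetical variable names, mirrors the repeat_size/tile_size bookkeeping above and checks it against a flattened "ij" meshgrid.

# Illustrative sketch only. It shows why the repeat/tile expansion in
# _check_index is equivalent to flattening an "ij" meshgrid of the
# per-dimension index vectors. Names here are hypothetical.
import numpy as np

idx = [np.array([0, 1]), np.array([0, 1, 2])]  # per-dimension index vectors
sizes = [len(v) for v in idx]

# Mirror the converter's bookkeeping: tile by the product of the earlier sizes,
# repeat by the product of the later sizes.
tile_size = [int(np.prod(sizes[:k])) for k in range(len(idx))]
repeat_size = [int(np.prod(sizes[k + 1:])) for k in range(len(idx))]

unfolded = [
    np.repeat(np.tile(idx[k], tile_size[k]), repeat_size[k]) for k in range(len(idx))
]

# The same rows come out of a flattened meshgrid.
mesh = [g.ravel() for g in np.meshgrid(*idx, indexing="ij")]
assert all((u == m).all() for u, m in zip(unfolded, mesh))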