From df9b22a0ebe4007b0295e0ee153de3504130b97f Mon Sep 17 00:00:00 2001 From: Liao Jianjin Date: Wed, 1 Sep 2021 13:41:51 +0800 Subject: [PATCH 1/5] onnx:add index_put --- python/tvm/relay/frontend/onnx.py | 38 +++++++++++ tests/python/frontend/onnx/test_forward.py | 77 ++++++++++++++++++++++ 2 files changed, 115 insertions(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f9b49204b85e..a815e4b63ca0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3019,6 +3019,7 @@ def _op_dispatch(cls, operator, inputs, attr, params): op_map = { "size": cls._size, "arange": cls._arange, + "index_put": cls._index_put, "reshape": cls._reshape, "embedding_bag": cls._embedding_bag, } @@ -3038,6 +3039,43 @@ def _size(cls, inputs, attr, params): def _arange(cls, inputs, attr, params): return _op.arange(inputs[0], inputs[1], inputs[2], dtype="int64") + @classmethod + def _check_index(cls, indices, values): + def unfolding_indices(indices, values): + n = len(indices) + flatten_indices = [] + slices_size = [] + for index in indices: + flatten_indices.append(_op.reshape(index, _op.const([-1]))) + slices_size.append(infer_shape(flatten_indices[-1])[0]) + repeat_size = [1] + tile_size = [1] + for i in range(1, n): + repeat_size.append(slices_size[-i] * repeat_size[-1]) + tile_size.append(slices_size[i - 1] * tile_size[-1]) + repeat_size.reverse() + unflod_slices = [] + for i in range(n): + unflod_slices.append(fold_constant(_op.repeat(_op.tile(flatten_indices[i], (tile_size[i],)), repeat_size[i], 0))) + return unflod_slices, _op.reshape(values, _op.const([-1])) + + values_shape = infer_shape(values) + if len(values_shape) != 1: + return unfolding_indices(indices, values) + return indices, values + + @classmethod + def _index_put(cls, inputs, attr, params): + in_tensor = inputs[0] + indices, values = cls._check_index(inputs[1: -2], inputs[-2]) + accumulate = inputs[-1].data.asnumpy() != 0 + if not accumulate: + mode = 'update' + else: + mode = 'add' + index_tensor = _op.stack(indices, axis=0) + return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode) + @classmethod def _reshape(cls, inputs, attr, params): return _op.reshape(inputs[0], inputs[1]) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a1d821686ed5..fd809f574a9e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5044,6 +5044,82 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None verify_embedding_bag(10, 3, [2, 10]) verify_embedding_bag(32, 2, [3, 3]) +def test_index_put(): + class _index_put_model(torch.nn.Module): + def __init__(self, indices, values, accumulate): + super(_index_put_model, self).__init__() + self.indices = indices + self.values = values + self.accumulate = accumulate + + def forward(self, x): + return x.index_put(self.indices, self.values, self.accumulate) + + def _convert_to_onnx(model, dummy_data): + file_name = "{}.onnx".format("aten_model") + torch.onnx.export( + model, + dummy_data, + file_name, + export_params=True, + verbose=False, + opset_version=11, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + ) + onnx_model = onnx.load(file_name) + return onnx_model + + def verify_index_put(data_shape, indices, accumulate): + dummy_data = torch.ones(data_shape) + tvm_inputs = [dummy_data.numpy()] + values = torch.rand(indices[0].size()) + model = _index_put_model(indices, values, accumulate) + onnx_model = _convert_to_onnx(model, dummy_data) + torch_out = model(dummy_data) + + for target, ctx in tvm.testing.enabled_targets(): + tvm_out = get_tvm_output_with_vm( + onnx_model, tvm_inputs, target, ctx, freeze_params=True, convert_to_static=True + ) + tvm.testing.assert_allclose(torch_out.numpy(), tvm_out) + + shape = (3, 5) + xidx = torch.tensor([0, 1, 2, 2]) + yidx = torch.tensor([0, 1, 3, 4]) + verify_index_put(shape, [xidx, yidx], True) + + shape = (3, 5, 3) + xidx = torch.tensor([0, 1, 2, 2, 0]) + yidx = torch.tensor([0, 1, 3, 4, 0]) + zidx = torch.tensor([0, 1, 1, 2, 0]) + verify_index_put(shape, [xidx, yidx, zidx], False) + + + def verify_index_put_slice(data_shape, value_shape, accumulate): + dummy_data = torch.ones(data_shape) + tvm_inputs = [dummy_data.numpy()] + indices = [] + index_shape = [1] * len(value_shape) + index_shape[0] = -1 + for i in range(len(value_shape)): + indices.append(torch.arange(0, value_shape[i]).reshape(tuple(index_shape))) + index_shape.pop() + values = torch.rand(value_shape) + + model = _index_put_model(indices, values, accumulate) + onnx_model = _convert_to_onnx(model, dummy_data) + torch_out = model(dummy_data) + + for target, ctx in tvm.testing.enabled_targets(): + tvm_out = get_tvm_output_with_vm( + onnx_model, tvm_inputs, target, ctx, freeze_params=True, convert_to_static=True + ) + tvm.testing.assert_allclose(torch_out.numpy(), tvm_out) + + verify_index_put_slice((3, 3), (2, 2), False) + verify_index_put_slice((2, 3, 4), (1, 2, 3), True) + verify_index_put_slice((2, 3, 4, 5), (1, 2, 3, 1), False) + @tvm.testing.parametrize_targets def test_reverse_sequence(target, dev): @@ -5621,6 +5697,7 @@ def repeat(N, D): test_cumsum() test_wrong_input() test_aten() + test_index_put() test_reverse_sequence() test_eyelike() test_qlinearconv() From 60fe75d0a06f1d2367b17f43f45dbeee6f57d714 Mon Sep 17 00:00:00 2001 From: liaojianjin Date: Thu, 2 Sep 2021 10:08:32 +0800 Subject: [PATCH 2/5] reformat code --- python/tvm/relay/frontend/onnx.py | 12 ++++++++---- tests/python/frontend/onnx/test_forward.py | 4 ++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a815e4b63ca0..93b87771008b 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3056,7 +3056,11 @@ def unfolding_indices(indices, values): repeat_size.reverse() unflod_slices = [] for i in range(n): - unflod_slices.append(fold_constant(_op.repeat(_op.tile(flatten_indices[i], (tile_size[i],)), repeat_size[i], 0))) + unflod_slices.append( + fold_constant( + _op.repeat(_op.tile(flatten_indices[i], (tile_size[i],)), repeat_size[i], 0) + ) + ) return unflod_slices, _op.reshape(values, _op.const([-1])) values_shape = infer_shape(values) @@ -3067,12 +3071,12 @@ def unfolding_indices(indices, values): @classmethod def _index_put(cls, inputs, attr, params): in_tensor = inputs[0] - indices, values = cls._check_index(inputs[1: -2], inputs[-2]) + indices, values = cls._check_index(inputs[1:-2], inputs[-2]) accumulate = inputs[-1].data.asnumpy() != 0 if not accumulate: - mode = 'update' + mode = "update" else: - mode = 'add' + mode = "add" index_tensor = _op.stack(indices, axis=0) return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index fd809f574a9e..482c35506ac8 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5044,6 +5044,7 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None verify_embedding_bag(10, 3, [2, 10]) verify_embedding_bag(32, 2, [3, 3]) + def test_index_put(): class _index_put_model(torch.nn.Module): def __init__(self, indices, values, accumulate): @@ -5082,7 +5083,7 @@ def verify_index_put(data_shape, indices, accumulate): onnx_model, tvm_inputs, target, ctx, freeze_params=True, convert_to_static=True ) tvm.testing.assert_allclose(torch_out.numpy(), tvm_out) - + shape = (3, 5) xidx = torch.tensor([0, 1, 2, 2]) yidx = torch.tensor([0, 1, 3, 4]) @@ -5094,7 +5095,6 @@ def verify_index_put(data_shape, indices, accumulate): zidx = torch.tensor([0, 1, 1, 2, 0]) verify_index_put(shape, [xidx, yidx, zidx], False) - def verify_index_put_slice(data_shape, value_shape, accumulate): dummy_data = torch.ones(data_shape) tvm_inputs = [dummy_data.numpy()] From 3fc47037fa804452ec4f36bedffa15556d561991 Mon Sep 17 00:00:00 2001 From: liaojianjin Date: Thu, 2 Sep 2021 11:49:39 +0800 Subject: [PATCH 3/5] add parametrize_targets --- tests/python/frontend/onnx/test_forward.py | 23 +++++++++++----------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 482c35506ac8..0d73ec3bc6f7 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5045,7 +5045,8 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None verify_embedding_bag(32, 2, [3, 3]) -def test_index_put(): +@tvm.testing.parametrize_targets +def test_index_put(target, dev): class _index_put_model(torch.nn.Module): def __init__(self, indices, values, accumulate): super(_index_put_model, self).__init__() @@ -5078,11 +5079,10 @@ def verify_index_put(data_shape, indices, accumulate): onnx_model = _convert_to_onnx(model, dummy_data) torch_out = model(dummy_data) - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm( - onnx_model, tvm_inputs, target, ctx, freeze_params=True, convert_to_static=True - ) - tvm.testing.assert_allclose(torch_out.numpy(), tvm_out) + tvm_out = get_tvm_output_with_vm( + onnx_model, tvm_inputs, target, dev, freeze_params=True, convert_to_static=True + ) + tvm.testing.assert_allclose(torch_out.numpy(), tvm_out) shape = (3, 5) xidx = torch.tensor([0, 1, 2, 2]) @@ -5110,11 +5110,10 @@ def verify_index_put_slice(data_shape, value_shape, accumulate): onnx_model = _convert_to_onnx(model, dummy_data) torch_out = model(dummy_data) - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm( - onnx_model, tvm_inputs, target, ctx, freeze_params=True, convert_to_static=True - ) - tvm.testing.assert_allclose(torch_out.numpy(), tvm_out) + tvm_out = get_tvm_output_with_vm( + onnx_model, tvm_inputs, target, dev, freeze_params=True, convert_to_static=True + ) + tvm.testing.assert_allclose(torch_out.numpy(), tvm_out) verify_index_put_slice((3, 3), (2, 2), False) verify_index_put_slice((2, 3, 4), (1, 2, 3), True) @@ -5696,8 +5695,8 @@ def repeat(N, D): test_softplus() test_cumsum() test_wrong_input() - test_aten() test_index_put() + test_aten() test_reverse_sequence() test_eyelike() test_qlinearconv() From 2d50a4a6b6e0c4994e558f6e5100cace5ac5beb3 Mon Sep 17 00:00:00 2001 From: liaojianjin Date: Thu, 2 Sep 2021 15:44:21 +0800 Subject: [PATCH 4/5] change slice to onnx_index instance --- python/tvm/relay/frontend/onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 93b87771008b..0038edf77a70 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3071,8 +3071,8 @@ def unfolding_indices(indices, values): @classmethod def _index_put(cls, inputs, attr, params): in_tensor = inputs[0] - indices, values = cls._check_index(inputs[1:-2], inputs[-2]) - accumulate = inputs[-1].data.asnumpy() != 0 + indices, values = cls._check_index(inputs[1 : len(inputs) - 2], inputs[len(inputs) - 2]) + accumulate = inputs[len(inputs) - 1].data.asnumpy() != 0 if not accumulate: mode = "update" else: From 0e83fee0cf9491e38a6ebb087e24ea7a5776a71c Mon Sep 17 00:00:00 2001 From: liaojianjin Date: Thu, 2 Sep 2021 21:48:49 +0800 Subject: [PATCH 5/5] modify test_forward --- tests/python/frontend/onnx/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 0d73ec3bc6f7..d6515be54985 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5695,8 +5695,8 @@ def repeat(N, D): test_softplus() test_cumsum() test_wrong_input() - test_index_put() test_aten() + test_index_put() test_reverse_sequence() test_eyelike() test_qlinearconv()