From 2be6c16abb8ddbd0b37d236cead91ed29bcc6bbe Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 May 2021 19:57:24 +0900 Subject: [PATCH 01/17] Add GatherND batch_dim support --- include/tvm/relay/attrs/transform.h | 9 +++++++++ include/tvm/topi/transform.h | 9 ++++++--- python/tvm/relay/frontend/onnx.py | 8 ++++++++ python/tvm/relay/op/transform.py | 4 ++-- src/relay/op/tensor/transform.cc | 16 ++++++++++++---- 5 files changed, 37 insertions(+), 9 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 723f9ecdab90..72842c1b2c98 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -144,6 +144,15 @@ struct GatherAttrs : public tvm::AttrsNode { } }; +struct GatherNDAttrs : public tvm::AttrsNode { + Integer batch_dim; + + TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") { + TVM_ATTR_FIELD(batch_dim) + .set_default(Integer(0)) + .describe("The number of batch dimensions."); + } +}; struct TakeAttrs : public tvm::AttrsNode { Integer batch_dims; Integer axis; diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index b2132b75fab9..08c9e327d7d8 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1243,8 +1243,8 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, * * \return A Tensor whose op member is the gather_nd operation */ -inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string name = "T_gather_nd", - std::string tag = kInjective) { +inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim = 0, + std::string name = "T_gather_nd", std::string tag = kInjective) { size_t ndim_d = data->shape.size(); size_t ndim_i = indices->shape.size(); ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions"; @@ -1255,7 +1255,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n for (size_t i = 1; i < 
ndim_i; ++i) { out_shape.push_back(indices->shape[i]); } - for (size_t i = indices_dim0; i < ndim_d; ++i) { + for (size_t i = indices_dim0 + batch_dim; i < ndim_d; ++i) { out_shape.push_back(data->shape[i]); } return compute( @@ -1267,6 +1267,9 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n indices_position.push_back(out_index[i]); } Array real_indices; + for (size_t i = 0; i < batch_dim; ++i) { + real_indices.push_back(out_index[i]); + } for (size_t i = 0; i < indices_dim0; ++i) { indices_position.Set(0, make_const(DataType::Int(32), i)); if (indices->dtype.is_int()) { diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b9fabdebb330..6743cfeae7ca 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1419,6 +1419,14 @@ def _impl_v1(cls, inputs, attr, params): indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1))) return _op.gather_nd(inputs[0], indices) + @classmethod + def _impl_v12(cls, inputs, attr, params): + indices_shape = infer_shape(inputs[1]) + indices_dims = len(indices_shape) + indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1))) + batch_dim = attr.get('batch_dims', 0) + return _op.gather_nd(inputs[0], indices, batch_dim) + class Scatter(OnnxOpConverter): """Operator converter for Scatter.""" diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 8744e7b5c6ad..6ce110adf77a 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1072,7 +1072,7 @@ def gather(data, axis, indices): return _make.gather(data, axis, indices) -def gather_nd(data, indices): +def gather_nd(data, indices, batch_dim=0): """Gather elements or slices from data and store to a tensor whose shape is defined by indices. 
@@ -1101,7 +1101,7 @@ def gather_nd(data, indices): indices = [[0, 1], [1, 0]] relay.gather_nd(data, indices) = [[3, 4], [5, 6]] """ - return _make.gather_nd(data, indices) + return _make.gather_nd(data, indices, batch_dim) def sequence_mask(data, valid_length, mask_value=0, axis=0): diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index df60aeb16bf3..424e683ec751 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3350,21 +3350,29 @@ bool GatherNDRel(const Array& types, int num_inputs, const Attrs& attrs, const size_t kdim = indices->shape.size() - 1; ICHECK(size_t(mdim->value) <= ndim) << "GatherND: indices shape does satisfy."; + const auto param = attrs.as(); + ICHECK(param != nullptr); + Array oshape; for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]); - for (size_t i = mdim->value; i < ndim; ++i) oshape.push_back(data->shape[i]); + for (size_t i = mdim->value + param->batch_dim->value; i < ndim; ++i) + oshape.push_back(data->shape[i]); reporter->Assign(types[2], TensorType(oshape, data->dtype)); return true; } Array GatherNDCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return {topi::gather_nd(inputs[0], inputs[1])}; + const auto* param = attrs.as(); + ICHECK(param); + return {topi::gather_nd(inputs[0], inputs[1], param->batch_dim)}; } -Expr MakeGatherND(Expr data, Expr indices) { +Expr MakeGatherND(Expr data, Expr indices, int batch_dim = 0) { static const Op& op = Op::Get("gather_nd"); - return Call(op, {data, indices}, {}); + auto attrs = make_object(); + attrs->batch_dim = std::move(batch_dim); + return Call(op, {data, indices}, Attrs(attrs)); } TVM_REGISTER_GLOBAL("relay.op._make.gather_nd").set_body_typed(MakeGatherND); From a0d28e2f3be22b0d8917bc627d333e33dc103f92 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 May 2021 21:11:44 +0900 Subject: [PATCH 02/17] adding tests --- python/tvm/relay/frontend/onnx.py | 2 +- 
tests/python/relay/test_op_level3.py | 30 ++++++++++++++++++++-------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 6743cfeae7ca..b86ab0093ebc 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1424,7 +1424,7 @@ def _impl_v12(cls, inputs, attr, params): indices_shape = infer_shape(inputs[1]) indices_dims = len(indices_shape) indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1))) - batch_dim = attr.get('batch_dims', 0) + batch_dim = attr.get("batch_dims", 0) return _op.gather_nd(inputs[0], indices, batch_dim) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index e84b22b30ce1..8903dc498ce0 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1252,25 +1252,38 @@ def verify_gather(data, axis, indices, ref_res): @tvm.testing.uses_gpu def test_gather_nd(): - def verify_gather_nd(xshape, yshape, y_data): + def verify_gather_nd(xshape, yshape, y_data, batch_dim=0): x = relay.var("x", relay.TensorType(xshape, "float32")) y = relay.var("y", relay.TensorType(yshape, "int32")) - z = relay.gather_nd(x, y) + z = relay.gather_nd(x, y, batch_dim) func = relay.Function([x, y], z) x_data = np.random.uniform(size=xshape).astype("float32") - ref_res = x_data[tuple(y_data)] + + if batch_dim > 0: + res = [] + for row, ind in zip(x_data, np.transpose(y_data)): + res.append(row[ind]) + ref_res = np.vstack(res) + print(ref_res.shape) + else: + ref_res = x_data[tuple(y_data)] for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, device=dev, target=target) op_res = intrp.evaluate(func)(x_data, y_data) + print(op_res.shape) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) - verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], 
[1, 0]]) - verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]]) - verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]]) + # verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) + # verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], [1, 0]]) + # verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]]) + # verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]]) + + # verify_gather_nd((2, 2, 2), (1, 2), [[1, 0]], 1) + # verify_gather_nd((2, 2, 2), (1, 1, 2), [[[1, 0]]], 1) + verify_gather_nd((2, 2, 2), (2, 1, 2), [[[1, 0]], [[0, 1]]], 1) def _verify_infiniteness_ops(relay_op, ref_op): @@ -1970,4 +1983,5 @@ def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): if __name__ == "__main__": - pytest.main([__file__]) + # pytest.main([__file__]) + test_gather_nd() From 33a7ab770b297ed7b85da209f33f59dac2c8d23d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 19 May 2021 22:54:58 +0900 Subject: [PATCH 03/17] test working --- tests/python/relay/test_op_level3.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 8903dc498ce0..313ef32e98d7 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1259,13 +1259,15 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dim=0): func = relay.Function([x, y], z) x_data = np.random.uniform(size=xshape).astype("float32") + y_data = np.array(y_data).astype(np.int32) if batch_dim > 0: res = [] - for row, ind in zip(x_data, np.transpose(y_data)): - res.append(row[ind]) - ref_res = np.vstack(res) - print(ref_res.shape) + axes = (tuple(range(1, len(y_data.shape))) + (0,)) + swapped = np.transpose(y_data, axes) + for row, ind in zip(x_data, swapped): + res.append(row[tuple(ind.T)]) + ref_res = np.stack(res, 0) else: ref_res = x_data[tuple(y_data)] @@ -1273,7 +1275,6 @@ def verify_gather_nd(xshape, 
yshape, y_data, batch_dim=0): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, device=dev, target=target) op_res = intrp.evaluate(func)(x_data, y_data) - print(op_res.shape) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) # verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) @@ -1281,9 +1282,9 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dim=0): # verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]]) # verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]]) - # verify_gather_nd((2, 2, 2), (1, 2), [[1, 0]], 1) - # verify_gather_nd((2, 2, 2), (1, 1, 2), [[[1, 0]]], 1) - verify_gather_nd((2, 2, 2), (2, 1, 2), [[[1, 0]], [[0, 1]]], 1) + verify_gather_nd((2, 2, 2), (1, 2), [[1, 0]], 1) + verify_gather_nd((2, 2, 2), (1, 2, 1), [[[1], [0]]], 1) + verify_gather_nd((2, 2, 2), (2, 2, 1), [[[1], [0]], [[0], [1]]], 1) def _verify_infiniteness_ops(relay_op, ref_op): From 466a113b730164f56be5a951e6c79dc999d4cefb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 19 May 2021 23:06:33 +0900 Subject: [PATCH 04/17] improved reference code --- tests/python/relay/test_op_level3.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 313ef32e98d7..416360f588cc 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1263,10 +1263,9 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dim=0): if batch_dim > 0: res = [] - axes = (tuple(range(1, len(y_data.shape))) + (0,)) - swapped = np.transpose(y_data, axes) - for row, ind in zip(x_data, swapped): - res.append(row[tuple(ind.T)]) + for i, row in enumerate(x_data): + indices = y_data[:, i] # the indices for the i-th batch + res.append(row[tuple(indices)]) ref_res = np.stack(res, 0) else: ref_res = x_data[tuple(y_data)] From 9875f45dcdf0b633074396b9bdb3bfca4f810667 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda 
Date: Thu, 20 May 2021 06:03:05 +0900 Subject: [PATCH 05/17] refactor ref func --- tests/python/relay/test_op_level3.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 416360f588cc..21b73788d428 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1261,12 +1261,16 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dim=0): x_data = np.random.uniform(size=xshape).astype("float32") y_data = np.array(y_data).astype(np.int32) - if batch_dim > 0: + def gather_nd_batch_dim_1_ref(data, indices): res = [] - for i, row in enumerate(x_data): - indices = y_data[:, i] # the indices for the i-th batch - res.append(row[tuple(indices)]) - ref_res = np.stack(res, 0) + for i, row in enumerate(data): + indices_tuple = tuple(indices[:, i]) # the indices for the i-th batch + res.append(row[indices_tuple]) + # stack on the batch dim + return np.stack(res, 0) + + if batch_dim > 0: + ref_res = gather_nd_batch_dim_1_ref(x_data, y_data) else: ref_res = x_data[tuple(y_data)] From 3ad2880cf5ccf8c332fafbb3e1a3f69722f8adf5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 May 2021 06:51:34 +0900 Subject: [PATCH 06/17] batch dim 2 tests from tf all passed --- tests/python/relay/test_op_level3.py | 47 +++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 21b73788d428..5f0860ccf208 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1258,8 +1258,13 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dim=0): z = relay.gather_nd(x, y, batch_dim) func = relay.Function([x, y], z) + x_data = np.random.uniform(size=xshape).astype("float32") - y_data = np.array(y_data).astype(np.int32) + + if y_data: + y_data = np.array(y_data, dtype="int32") + else: + y_data = 
np.random.randint(low=0, high=2, size=yshape, dtype="int32") def gather_nd_batch_dim_1_ref(data, indices): res = [] @@ -1269,7 +1274,15 @@ def gather_nd_batch_dim_1_ref(data, indices): # stack on the batch dim return np.stack(res, 0) - if batch_dim > 0: + if batch_dim > 1: + x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dim:]) + y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dim + 1) :]) + + ref_res = gather_nd_batch_dim_1_ref(x_data_reshape, y_data_reshape) + + out_shape = yshape[1 : (batch_dim + 1)] + ref_res.shape[1:] + ref_res = np.reshape(ref_res, out_shape) + elif batch_dim == 1: ref_res = gather_nd_batch_dim_1_ref(x_data, y_data) else: ref_res = x_data[tuple(y_data)] @@ -1280,15 +1293,34 @@ def gather_nd_batch_dim_1_ref(data, indices): op_res = intrp.evaluate(func)(x_data, y_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - # verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) - # verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], [1, 0]]) - # verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]]) - # verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]]) + verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) + verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], [1, 0]]) + verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]]) + verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]]) + # Examples from tensorflow gather_nd doc + # https://www.tensorflow.org/api_docs/python/tf/gather_nd verify_gather_nd((2, 2, 2), (1, 2), [[1, 0]], 1) verify_gather_nd((2, 2, 2), (1, 2, 1), [[[1], [0]]], 1) verify_gather_nd((2, 2, 2), (2, 2, 1), [[[1], [0]], [[0], [1]]], 1) + # Test cases from tensorflow gather_nd tests kernel_tests/array_ops_test.py + verify_gather_nd((2, 2, 2), (1, 2), None, 1) + verify_gather_nd((2, 2, 2), (2, 2), None, 1) + verify_gather_nd((2, 2, 3, 2), (3, 2), None, 1) + verify_gather_nd((2, 2, 3, 2), (2, 2), None, 1) + verify_gather_nd((2, 2, 3, 
2), (1, 2), None, 1) + verify_gather_nd((2, 2, 3, 2), (3, 2, 1), None, 1) + verify_gather_nd((2, 2, 3, 2), (2, 2, 2), None, 1) + verify_gather_nd((2, 2, 3, 2), (1, 2, 3), None, 1) + + verify_gather_nd((3, 2, 2, 3, 4), (3, 3, 2), None, 2) + verify_gather_nd((3, 2, 2, 3, 4), (2, 3, 2), None, 2) + verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2), None, 2) + verify_gather_nd((3, 2, 2, 3, 4), (3, 3, 2, 1), None, 2) + verify_gather_nd((3, 2, 2, 3, 4), (2, 3, 2, 2), None, 2) + verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2, 3), None, 2) + def _verify_infiniteness_ops(relay_op, ref_op): for dtype in ["float32", "float16", "float16", "int32", "int16"]: @@ -1987,5 +2019,4 @@ def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): if __name__ == "__main__": - # pytest.main([__file__]) - test_gather_nd() + pytest.main([__file__]) From e2dc34848cafed61acb47a35677c606df3ea5057 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 May 2021 07:01:58 +0900 Subject: [PATCH 07/17] batch_dim -> batch_dims --- include/tvm/relay/attrs/transform.h | 4 ++-- include/tvm/topi/transform.h | 7 ++++--- python/tvm/relay/frontend/onnx.py | 4 ++-- python/tvm/relay/op/transform.py | 7 +++++-- src/relay/op/tensor/transform.cc | 8 ++++---- 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 72842c1b2c98..547cd78fef03 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -145,10 +145,10 @@ struct GatherAttrs : public tvm::AttrsNode { }; struct GatherNDAttrs : public tvm::AttrsNode { - Integer batch_dim; + Integer batch_dims; TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") { - TVM_ATTR_FIELD(batch_dim) + TVM_ATTR_FIELD(batch_dims) .set_default(Integer(0)) .describe("The number of batch dimensions."); } diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 08c9e327d7d8..c9ee63390c8f 100644 --- 
a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1238,12 +1238,13 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, * * \param data The source array. * \param indices The indices of the values to extract. + * \param batch_dims The number of batch dimensions. * \param name The name of the operation. * \param tag The tag to mark the operation. * * \return A Tensor whose op member is the gather_nd operation */ -inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim = 0, +inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0, std::string name = "T_gather_nd", std::string tag = kInjective) { size_t ndim_d = data->shape.size(); size_t ndim_i = indices->shape.size(); @@ -1255,7 +1256,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim for (size_t i = 1; i < ndim_i; ++i) { out_shape.push_back(indices->shape[i]); } - for (size_t i = indices_dim0 + batch_dim; i < ndim_d; ++i) { + for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) { out_shape.push_back(data->shape[i]); } return compute( @@ -1267,7 +1268,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim indices_position.push_back(out_index[i]); } Array real_indices; - for (size_t i = 0; i < batch_dim; ++i) { + for (size_t i = 0; i < batch_dims; ++i) { real_indices.push_back(out_index[i]); } for (size_t i = 0; i < indices_dim0; ++i) { diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b86ab0093ebc..7c8696eab1ef 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1424,8 +1424,8 @@ def _impl_v12(cls, inputs, attr, params): indices_shape = infer_shape(inputs[1]) indices_dims = len(indices_shape) indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1))) - batch_dim = attr.get("batch_dims", 0) - return _op.gather_nd(inputs[0], indices, batch_dim) + 
batch_dims = attr.get("batch_dims", 0) + return _op.gather_nd(inputs[0], indices, batch_dims) class Scatter(OnnxOpConverter): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 6ce110adf77a..ad132cfdc2f5 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1072,7 +1072,7 @@ def gather(data, axis, indices): return _make.gather(data, axis, indices) -def gather_nd(data, indices, batch_dim=0): +def gather_nd(data, indices, batch_dims=0): """Gather elements or slices from data and store to a tensor whose shape is defined by indices. @@ -1084,6 +1084,9 @@ def gather_nd(data, indices, batch_dim=0): indices : relay.Expr The shape of output tensor. + batch_dims : int + The number of batch dimensions + Returns ------- ret : relay.Expr @@ -1101,7 +1104,7 @@ def gather_nd(data, indices, batch_dim=0): indices = [[0, 1], [1, 0]] relay.gather_nd(data, indices) = [[3, 4], [5, 6]] """ - return _make.gather_nd(data, indices, batch_dim) + return _make.gather_nd(data, indices, batch_dims) def sequence_mask(data, valid_length, mask_value=0, axis=0): diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 424e683ec751..9e49d86eb063 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3355,7 +3355,7 @@ bool GatherNDRel(const Array& types, int num_inputs, const Attrs& attrs, Array oshape; for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]); - for (size_t i = mdim->value + param->batch_dim->value; i < ndim; ++i) + for (size_t i = mdim->value + param->batch_dims->value; i < ndim; ++i) oshape.push_back(data->shape[i]); reporter->Assign(types[2], TensorType(oshape, data->dtype)); return true; @@ -3365,13 +3365,13 @@ Array GatherNDCompute(const Attrs& attrs, const Array& i const Type& out_type) { const auto* param = attrs.as(); ICHECK(param); - return {topi::gather_nd(inputs[0], inputs[1], param->batch_dim)}; + return 
{topi::gather_nd(inputs[0], inputs[1], param->batch_dims)}; } -Expr MakeGatherND(Expr data, Expr indices, int batch_dim = 0) { +Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0) { static const Op& op = Op::Get("gather_nd"); auto attrs = make_object(); - attrs->batch_dim = std::move(batch_dim); + attrs->batch_dims = std::move(batch_dims); return Call(op, {data, indices}, Attrs(attrs)); } From dffef85441bc0657e8ba857f5befe8d9937c4ab5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 May 2021 07:04:36 +0900 Subject: [PATCH 08/17] add example --- python/tvm/relay/op/transform.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index ad132cfdc2f5..7523ee06321e 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1103,6 +1103,10 @@ def gather_nd(data, indices, batch_dims=0): data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] indices = [[0, 1], [1, 0]] relay.gather_nd(data, indices) = [[3, 4], [5, 6]] + + data = [[[0,1],[2,3]],[[4,5],[6,7]]] + indices = [[1, 0]] + relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]] """ return _make.gather_nd(data, indices, batch_dims) From e0e63154bd6d00110eda7d086d80139262b89722 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 May 2021 07:14:30 +0900 Subject: [PATCH 09/17] minor change --- python/tvm/relay/op/transform.py | 2 +- tests/python/relay/test_op_level3.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 7523ee06321e..eeb8644d4328 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1085,7 +1085,7 @@ def gather_nd(data, indices, batch_dims=0): The shape of output tensor. batch_dims : int - The number of batch dimensions + The number of batch dimensions. 
Returns ------- diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 5f0860ccf208..a4d717322156 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1252,10 +1252,10 @@ def verify_gather(data, axis, indices, ref_res): @tvm.testing.uses_gpu def test_gather_nd(): - def verify_gather_nd(xshape, yshape, y_data, batch_dim=0): + def verify_gather_nd(xshape, yshape, y_data, batch_dims=0): x = relay.var("x", relay.TensorType(xshape, "float32")) y = relay.var("y", relay.TensorType(yshape, "int32")) - z = relay.gather_nd(x, y, batch_dim) + z = relay.gather_nd(x, y, batch_dims) func = relay.Function([x, y], z) @@ -1266,7 +1266,7 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dim=0): else: y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32") - def gather_nd_batch_dim_1_ref(data, indices): + def gather_nd_batch_dims_1_ref(data, indices): res = [] for i, row in enumerate(data): indices_tuple = tuple(indices[:, i]) # the indices for the i-th batch @@ -1274,16 +1274,16 @@ def gather_nd_batch_dim_1_ref(data, indices): # stack on the batch dim return np.stack(res, 0) - if batch_dim > 1: - x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dim:]) - y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dim + 1) :]) + if batch_dims > 1: + x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:]) + y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1):]) - ref_res = gather_nd_batch_dim_1_ref(x_data_reshape, y_data_reshape) + ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape) - out_shape = yshape[1 : (batch_dim + 1)] + ref_res.shape[1:] + out_shape = yshape[1: (batch_dims + 1)] + ref_res.shape[1:] ref_res = np.reshape(ref_res, out_shape) - elif batch_dim == 1: - ref_res = gather_nd_batch_dim_1_ref(x_data, y_data) + elif batch_dims == 1: + ref_res = gather_nd_batch_dims_1_ref(x_data, y_data) else: ref_res = 
x_data[tuple(y_data)] From 039a2c4b14566bd3a092109781afc85dab71f867 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 May 2021 11:30:56 +0900 Subject: [PATCH 10/17] add onnx test --- python/tvm/relay/frontend/onnx.py | 14 ++++++------ tests/python/frontend/onnx/test_forward.py | 25 ++++++++++++++++++++-- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7c8696eab1ef..774f29cc0743 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1413,19 +1413,19 @@ def _impl_v1(cls, inputs, attr, params): class GatherND(OnnxOpConverter): """Operator converter for GatherND.""" + def _impl_common(data, indices, batch_dims=0): + indices_dims = len(infer_shape(indices)) + indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1))) + return _op.gather_nd(data, indices, batch_dims) + @classmethod def _impl_v1(cls, inputs, attr, params): - indices_dims = len(infer_shape(inputs[1])) - indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1))) - return _op.gather_nd(inputs[0], indices) + return cls._impl_common(inputs[0], inputs[1]) @classmethod def _impl_v12(cls, inputs, attr, params): - indices_shape = infer_shape(inputs[1]) - indices_dims = len(indices_shape) - indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1))) batch_dims = attr.get("batch_dims", 0) - return _op.gather_nd(inputs[0], indices, batch_dims) + return cls._impl_common(inputs[0], inputs[1], batch_dims) class Scatter(OnnxOpConverter): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index fdb8d205a244..2bd7110b1d7f 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import re import numpy as np import pytest import scipy @@ -218,6 +219,12 @@ def make_constant_node(name, data_type, dims, vals): ) +def is_version_greater_than(ver): + return "".join(re.findall(r"(\d+\.)(\d+\.)(\d)", onnx.__version__)[0]) > "".join( + re.findall(r"(\d+\.)(\d+\.)(\d)", ver)[0] + ) + + @tvm.testing.uses_gpu def test_reshape(): in_shape = (4, 3, 3, 4) @@ -1002,12 +1009,16 @@ def test_isnan(): _test_finite_ops((2, 4, 5, 6), np.isnan, {}, "float32", "IsNaN", {}) -def verify_gather_nd(in_shape, indices, out_shape, dtype="float32"): +def verify_gather_nd(in_shape, indices, out_shape, dtype="float32", batch_dims=0, opset=11): x = np.random.uniform(size=in_shape).astype(dtype) indices = np.array(indices, dtype="int64") y = helper.make_node("GatherND", ["in", "indices"], ["out"]) + if opset >= 12: + batch_dims_attr = helper.make_attribute("batch_dims", batch_dims) + y.attribute.append(batch_dims_attr) + graph = helper.make_graph( [y], "gather_test", @@ -1024,7 +1035,7 @@ def verify_gather_nd(in_shape, indices, out_shape, dtype="float32"): ], ) model = helper.make_model(graph, producer_name="gather_test") - verify_with_ort_with_inputs(model, [x, indices], [out_shape]) + verify_with_ort_with_inputs(model, [x, indices], [out_shape], opset=opset) @tvm.testing.uses_gpu @@ -1034,6 +1045,16 @@ def test_gather_nd(): verify_gather_nd([2, 2, 2], [[0, 1], [1, 0]], [2, 2]) verify_gather_nd([2, 2, 2], [[[0, 1]], [[1, 0]]], [2, 1, 2]) + if is_version_greater_than("1.7.0"): + verify_gather_nd([2, 2, 2], [[1], [0]], [2, 2], batch_dims=1, opset=12) + verify_gather_nd( + (3, 2, 2, 3, 4), + np.random.randint(low=0, high=2, size=(3, 2, 3), dtype="int64"), + (3, 2), + batch_dims=2, + opset=12, + ) + @tvm.testing.uses_gpu def test_onehot(): From 4ac8cfc1c25b74a71557f36d155ec60d81c9bc71 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 May 2021 11:33:55 +0900 Subject: [PATCH 11/17] fix onnx version --- tests/python/frontend/onnx/test_forward.py | 2 +- 1 file changed, 
1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 2bd7110b1d7f..aaf524cc9dcc 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1045,7 +1045,7 @@ def test_gather_nd(): verify_gather_nd([2, 2, 2], [[0, 1], [1, 0]], [2, 2]) verify_gather_nd([2, 2, 2], [[[0, 1]], [[1, 0]]], [2, 1, 2]) - if is_version_greater_than("1.7.0"): + if is_version_greater_than("1.6.0"): verify_gather_nd([2, 2, 2], [[1], [0]], [2, 2], batch_dims=1, opset=12) verify_gather_nd( (3, 2, 2, 3, 4), From d0446154a15a701682d9c71da9fc39ec3fa5b973 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 May 2021 11:38:48 +0900 Subject: [PATCH 12/17] fix lint --- include/tvm/relay/attrs/transform.h | 4 +--- tests/python/relay/test_op_level3.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 547cd78fef03..cc97a94a1406 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -148,9 +148,7 @@ struct GatherNDAttrs : public tvm::AttrsNode { Integer batch_dims; TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") { - TVM_ATTR_FIELD(batch_dims) - .set_default(Integer(0)) - .describe("The number of batch dimensions."); + TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions."); } }; struct TakeAttrs : public tvm::AttrsNode { diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index a4d717322156..b8bab295ba67 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1276,11 +1276,11 @@ def gather_nd_batch_dims_1_ref(data, indices): if batch_dims > 1: x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:]) - y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1):]) + y_data_reshape 
= np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1) :]) ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape) - out_shape = yshape[1: (batch_dims + 1)] + ref_res.shape[1:] + out_shape = yshape[1 : (batch_dims + 1)] + ref_res.shape[1:] ref_res = np.reshape(ref_res, out_shape) elif batch_dims == 1: ref_res = gather_nd_batch_dims_1_ref(x_data, y_data) From 4be279df0b22889c007c2b742a2ca2cd137a7733 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 May 2021 12:12:16 +0900 Subject: [PATCH 13/17] remove move on batch_dims --- src/relay/op/tensor/transform.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9e49d86eb063..137bb73ddc34 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3371,7 +3371,7 @@ Array GatherNDCompute(const Attrs& attrs, const Array& i Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0) { static const Op& op = Op::Get("gather_nd"); auto attrs = make_object(); - attrs->batch_dims = std::move(batch_dims); + attrs->batch_dims = batch_dims; return Call(op, {data, indices}, Attrs(attrs)); } From 76dce71c077358b9267460eafbe9211f5d8ec03a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 May 2021 14:35:05 +0900 Subject: [PATCH 14/17] fix pylint --- python/tvm/relay/frontend/onnx.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 774f29cc0743..e70167a6aa57 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1413,7 +1413,8 @@ def _impl_v1(cls, inputs, attr, params): class GatherND(OnnxOpConverter): """Operator converter for GatherND.""" - def _impl_common(data, indices, batch_dims=0): + @classmethod + def _impl_common(cls, data, indices, batch_dims=0): indices_dims = len(infer_shape(indices)) indices = _op.transpose(indices, axes=[-1] + 
list(range(indices_dims - 1))) return _op.gather_nd(data, indices, batch_dims) From 5692540dce7fac601930b1184a00b778067fbe5b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 May 2021 14:48:10 +0900 Subject: [PATCH 15/17] fix compiler warning --- include/tvm/topi/transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index c9ee63390c8f..781c1cbeb311 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1268,7 +1268,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim indices_position.push_back(out_index[i]); } Array<PrimExpr> real_indices; - for (size_t i = 0; i < batch_dims; ++i) { + for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) { real_indices.push_back(out_index[i]); } for (size_t i = 0; i < indices_dim0; ++i) { From 23ac5cf612ebd5e73ce13986da4db24db363575f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 May 2021 03:28:04 +0900 Subject: [PATCH 16/17] add shape constraint for batch_dim and update doc --- src/relay/op/tensor/transform.cc | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 137bb73ddc34..fb99fcfe5d3b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3353,6 +3353,11 @@ bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, const auto param = attrs.as<GatherNDAttrs>(); ICHECK(param != nullptr); + for (int i = 0; i < param->batch_dims->value; ++i) { + ICHECK(reporter->AssertEQ( + data->shape[i], indices->shape[i + 1])); // +1 since the first axis is the index tuple + } + Array<IndexExpr> oshape; for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]); for (size_t i = mdim->value + param->batch_dims->value; i < ndim; ++i) @@ -3381,9 +3386,15 @@ RELAY_REGISTER_OP("gather_nd") .describe(R"code(Gather elements or slices from data and store to
a tensor whose shape is defined by indices. -Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with -shape (M, Y_0, ..., Y_{K-1}), the output will have shape -(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N, +Optionally, batch_dims, the number of batch dimensions, can be given, whose +default value is 0. + +Let B denote batch_dims, and data, indices shape be (X_0, X_1, ..., X_{N-1}), +(M, Y_0, ..., Y_{K-1}) respectively. When B > 0, indexing will start from the B-th axis, +and it must be the case that X_0, ... X_{B-1} == Y_0, ... Y_{B-1}. + +The output will have shape +(Y_0, ..., Y_{B-1}, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), where M + B <= N. If M + B == N, output shape will simply be (Y_0, ..., Y_{K-1}). )code" TVM_ADD_FILELINE) .set_num_inputs(2) From b5dd79345014ed302feed954899edd42f81236e5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 May 2021 05:48:32 +0900 Subject: [PATCH 17/17] make the output shape doc clearer --- src/relay/op/tensor/transform.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index fb99fcfe5d3b..d6c19f9a7034 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3390,12 +3390,15 @@ Optionally, batch_dims, the number of batch dimensions, can be given, whose default value is 0. Let B denote batch_dims, and data, indices shape be (X_0, X_1, ..., X_{N-1}), -(M, Y_0, ..., Y_{K-1}) respectively. When B > 0, indexing will start from the B-th axis, -and it must be the case that X_0, ... X_{B-1} == Y_0, ... Y_{B-1}. +(M, Y_0, ..., Y_{K-1}) respectively. -The output will have shape -(Y_0, ..., Y_{B-1}, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), where M + B <= N. If M + B == N, -output shape will simply be (Y_0, ..., Y_{K-1}). +When B > 0, indexing will start from the B-th axis, and it must be the case that +X_0, ... X_{B-1} == Y_0, ... Y_{B-1}. 
The output will have a shape +(X_0, ..., X_{B-1}, Y_B, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), where M + B <= N. + +When B == 0 (the default case), the output shape will be (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}). + +In both cases, if M + B == N, the output shape will simply be (Y_0, ..., Y_{K-1}). )code" TVM_ADD_FILELINE) .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.")