From 593b188b93565f10d62b9ca3bac92419a85a57b8 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 8 Sep 2021 16:45:59 -0700 Subject: [PATCH 01/18] support negatibve indices in gather --- include/tvm/relay/attrs/transform.h | 4 ++++ include/tvm/topi/transform.h | 23 ++++++++++++++++------ python/tvm/relay/frontend/onnx.py | 7 ++++++- python/tvm/relay/op/transform.py | 8 ++++++-- src/relay/op/tensor/transform.cc | 5 +++-- tests/python/frontend/onnx/test_forward.py | 3 ++- 6 files changed, 38 insertions(+), 12 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index a8317e1e51ad..fd296e02b949 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -136,11 +136,15 @@ struct ScatterNDAttrs : public tvm::AttrsNode { struct GatherAttrs : public tvm::AttrsNode { Integer axis; + Bool support_negative_indices = Bool(false); TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherAttrs") { TVM_ATTR_FIELD(axis) .set_default(NullValue()) .describe("The axis over which to select values."); + TVM_ATTR_FIELD(support_negative_indices) + .set_default(Bool(false)) + .describe("If negative indices are supported."); } }; diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 8d1a49a4cc5f..9d324d8f3093 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1219,11 +1219,13 @@ inline Tensor dyn_tile(const Tensor& x, Array new_shape, size_t rdim, * \param indices The indices of values to gather. * \param name The name of the operation. * \param tag The tag to mark the operation. + * \param support_negative_indices If negative indices are supported * * \return A Tensor whose op member is the gather operation */ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, - std::string name = "T_gather", std::string tag = kInjective) { + bool support_negative_indices = false, std::string name = "T_gather", + std::string tag = kInjective) { size_t ndim_d = data->shape.size(); size_t ndim_i = indices->shape.size(); ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar."; @@ -1242,6 +1244,8 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, out_shape.push_back(indices->shape[i]); } + PrimExpr axis_size = data->shape[axis]; + return compute( out_shape, [&](const Array& out_index) { @@ -1252,7 +1256,13 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, Array real_indices; for (size_t i = 0; i < ndim_i; ++i) { if (i == static_cast(axis)) { - real_indices.push_back(indices(indices_position)); + PrimExpr index = indices(indices_position); + + // negative indices support is expensive so make it optional + if (support_negative_indices) { + index = indexmod(index, axis_size); + } + real_indices.push_back(index); } else { real_indices.push_back(indices_position[i]); } @@ -1302,11 +1312,12 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim } for (size_t i = 0; i < indices_dim0; ++i) { indices_position.Set(0, make_const(DataType::Int(32), i)); - if (indices->dtype.is_int()) { - real_indices.push_back(indices(indices_position)); - } else { - real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position))); + PrimExpr index = indices(indices_position); + + if (!indices->dtype.is_int()) { + index = tvm::cast(tvm::DataType::Int(32), index); } + real_indices.push_back(index); } if (real_indices.size() == ndim_d) { return data(real_indices); diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 2aba807c009f..2fb310f68215 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3550,7 +3550,12 @@ def _impl_v13(cls, inputs, attr, params): dtype=input_tensor.type_annotation.dtype, ) - loss = -relay.gather(input_tensor, axis=1, indices=relay.expand_dims(target_tensor, 1)) + loss = -relay.gather( + input_tensor, + axis=1, + indices=relay.expand_dims(target_tensor, 1), + support_negative_indices=True, + ) loss = relay.squeeze(loss, axis=[1]) expanded_target_tensor = relay.expand_dims(target_tensor, 0) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 2c299022bd6e..4ef8d28e106a 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1046,7 +1046,7 @@ def reverse_reshape(data, newshape): return _make.contrib_reverse_reshape(data, list(newshape)) -def gather(data, axis, indices): +def gather(data, axis, indices, support_negative_indices=False): """Gather values along given axis from given indices. E.g. for a 3D tensor, output is computed as: @@ -1071,6 +1071,10 @@ def gather(data, axis, indices): indices: relay.Expr The indices of values to gather. + support_negative_indices: bool + If True, support indices being negative. This is slower than supporting only + positive indices. + Examples -------- .. code-block:: python @@ -1080,7 +1084,7 @@ def gather(data, axis, indices): indices = [[0, 0], [1, 0]] relay.gather(data, axis, indices) = [[1, 1], [4, 3]] """ - return _make.gather(data, axis, indices) + return _make.gather(data, axis, indices, support_negative_indices) def gather_nd(data, indices, batch_dims=0, index_rank=None): diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 3781107eeee1..ab0f9a515482 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3274,12 +3274,13 @@ bool GatherRel(const Array& types, int num_inputs, const Attrs& attrs, Array GatherCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); - return {topi::gather(inputs[0], param->axis, inputs[1])}; + return {topi::gather(inputs[0], param->axis, inputs[1], param->support_negative_indices)}; } -Expr MakeGather(Expr data, Integer axis, Expr indices) { +Expr MakeGather(Expr data, Integer axis, Expr indices, Bool support_negative_indices) { auto attrs = make_object(); attrs->axis = std::move(axis); + attrs->support_negative_indices = std::move(support_negative_indices); static const Op& op = Op::Get("gather"); return Call(op, {data, indices}, Attrs(attrs), {}); } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 23f1dca0e8af..81d478c54fca 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -235,7 +235,8 @@ def verify_with_ort( def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev): - from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize_static + from onnxruntime.quantization import (CalibrationDataReader, QuantType, + quantize_static) input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes] From 7c19abbbe9b5e74c2bcca31be75e7abf615fb25c Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 9 Sep 2021 16:12:58 -0700 Subject: [PATCH 02/18] move check to Tensor level indexing, gathernd --- include/tvm/relay/attrs/transform.h | 4 ++++ include/tvm/te/tensor.h | 4 ++-- include/tvm/topi/transform.h | 13 +++++-------- python/tvm/relay/op/transform.py | 8 ++++++-- src/relay/op/tensor/transform.cc | 5 ++++- src/te/tensor.cc | 21 +++++++++++++++------ 6 files changed, 36 insertions(+), 19 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index fd296e02b949..4d6386b59e9d 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -151,6 +151,7 @@ struct GatherAttrs : public tvm::AttrsNode { struct GatherNDAttrs : public tvm::AttrsNode { Integer batch_dims; Optional index_rank; + Bool support_negative_indices = Bool(false); TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") { TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions."); @@ -159,6 +160,9 @@ struct GatherNDAttrs : public tvm::AttrsNode { .describe( "The size of an indexing tuple, which is a fixed value. Only needed when the number of " "indexting tuples is dynamic."); + TVM_ATTR_FIELD(support_negative_indices) + .set_default(Bool(false)) + .describe("If negative indices are supported."); } }; diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 85677a726574..4054c51bf99d 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -131,13 +131,13 @@ class Tensor : public DataProducer { * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr operator()(Array indices) const; + TVM_DLL PrimExpr operator()(Array indices, bool support_negative_indices = false) const; /*! * \brief Take elements from the tensor * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr operator()(Array indices) const; + TVM_DLL PrimExpr operator()(Array indices, bool support_negative_indices = false) const; /*! * \brief data structure to represent a slice that fixes first k coordinates. * This is used to enable syntax sugar of Tensor[x][y][z] to get the element. diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 9d324d8f3093..de09d21733ac 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1257,17 +1257,12 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, for (size_t i = 0; i < ndim_i; ++i) { if (i == static_cast(axis)) { PrimExpr index = indices(indices_position); - - // negative indices support is expensive so make it optional - if (support_negative_indices) { - index = indexmod(index, axis_size); - } real_indices.push_back(index); } else { real_indices.push_back(indices_position[i]); } } - return data(real_indices); + return data(real_indices, support_negative_indices); }, name, tag); } @@ -1280,11 +1275,13 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, * \param batch_dims The number of batch dimensions. * \param name The name of the operation. * \param tag The tag to mark the operation. + * \param support_negative_indices If negative indices are supported * * \return A Tensor whose op member is the gather_nd operation */ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0, - std::string name = "T_gather_nd", std::string tag = kInjective) { + bool support_negative_indices = false, std::string name = "T_gather_nd", + std::string tag = kInjective) { size_t ndim_d = data->shape.size(); size_t ndim_i = indices->shape.size(); ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions"; @@ -1325,7 +1322,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim for (size_t i = ndim_i - 1; i < out_index.size(); ++i) { real_indices.push_back(out_index[i]); } - return data(real_indices); + return data(real_indices, support_negative_indices); }, name, tag); } diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 4ef8d28e106a..cd1c8e613c11 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1087,7 +1087,7 @@ def gather(data, axis, indices, support_negative_indices=False): return _make.gather(data, axis, indices, support_negative_indices) -def gather_nd(data, indices, batch_dims=0, index_rank=None): +def gather_nd(data, indices, batch_dims=0, support_negative_indices=False, index_rank=None): """Gather elements or slices from data and store to a tensor whose shape is defined by indices. @@ -1106,6 +1106,10 @@ def gather_nd(data, indices, batch_dims=0, index_rank=None): The size of an indexing tuple, which is a fixed value and the same as indices.shape[0] Only needed when other dimensions of indices are dynamic. + support_negative_indices: bool + If True, support indices being negative. This is slower than supporting only + positive indices. + Returns ------- ret : relay.Expr @@ -1127,7 +1131,7 @@ def gather_nd(data, indices, batch_dims=0, index_rank=None): indices = [[1, 0]] relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]] """ - return _make.gather_nd(data, indices, batch_dims, index_rank) + return _make.gather_nd(data, indices, batch_dims, support_negative_indices, index_rank) def sequence_mask(data, valid_length, mask_value=0, axis=0): diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ab0f9a515482..e92843816cee 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3354,15 +3354,18 @@ Array GatherNDCompute(const Attrs& attrs, const Array& i const Type& out_type) { const auto* param = attrs.as(); ICHECK(param); - return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)}; + return { + topi::gather_nd(inputs[0], inputs[1], param->batch_dims, param->support_negative_indices)}; } Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, + Bool support_negative_indices = Bool(0), Optional index_rank = NullValue()) { static const Op& op = Op::Get("gather_nd"); auto attrs = make_object(); attrs->batch_dims = batch_dims; attrs->index_rank = index_rank; + attrs->support_negative_indices = support_negative_indices; return Call(op, {data, indices}, Attrs(attrs)); } diff --git a/src/te/tensor.cc b/src/te/tensor.cc index b48f39a38627..6db3be040383 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -39,17 +39,26 @@ IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name) Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } // Tensor -PrimExpr Tensor::operator()(Array indices) const { +PrimExpr Tensor::operator()(Array indices, bool support_negative_indices) const { Array arr(indices.begin(), indices.end()); - return operator()(arr); + return operator()(arr, support_negative_indices); } -PrimExpr Tensor::operator()(Array indices) const { - if (ndim() != 0) { - ICHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read " - << "ndim = " << ndim() << ", indices.size=" << indices.size(); +PrimExpr Tensor::operator()(Array indices, bool support_negative_indices) const { + Array shape = (*this)->shape; + + if (shape.size() != 0) { + ICHECK_EQ(shape.size(), indices.size()) + << "Tensor dimension mismatch in read " + << "ndim = " << ndim() << ", indices.size=" << indices.size(); } + if (support_negative_indices) { + for (int i = 0; i < shape.size(); i++) { + PrimExpr new_index = indexmod(indices[i], shape[i]); + indices.Set(i, new_index); + } + } return ProducerLoad((*this), indices); } From 160cbb4cc090220fd31b21c6c9c1d639a12b689a Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 9 Sep 2021 16:41:04 -0700 Subject: [PATCH 03/18] add test, update transform.h --- include/tvm/topi/transform.h | 2 +- tests/python/relay/test_op_level3.py | 20 ++++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index de09d21733ac..ab6f7e83b4b9 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1317,7 +1317,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim real_indices.push_back(index); } if (real_indices.size() == ndim_d) { - return data(real_indices); + return data(real_indices, support_negative_indices); } for (size_t i = ndim_i - 1; i < out_index.size(); ++i) { real_indices.push_back(out_index[i]); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index eaddd33678df..4b4f10e0c569 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -21,10 +21,8 @@ import numpy as np import pytest - import tvm import tvm.testing - from tvm import relay, te from tvm.error import TVMError from tvm.relay import create_executor, transform @@ -32,7 +30,6 @@ from utils import ref_funcs - executor_kind = tvm.testing.parameter("graph", "debug") @@ -1241,13 +1238,16 @@ def test_scatter_add(self, target, dev, ref_data, dshape, ishape, axis, dtype): ], ) def test_gather(target, dev, executor_kind, data, axis, indices, ref_res): - def verify_gather(data, axis, indices, ref_res): + def verify_gather(data, axis, indices, ref_res, check_negative=False): data = np.asarray(data, dtype="float32") indices = np.asarray(indices, dtype="int32") ref_res = np.asarray(ref_res) + if check_negative: + axis_size = data.shape[axis] + indices = indices - axis_size d = relay.var("x", relay.TensorType(data.shape, "float32")) i = relay.var("y", relay.TensorType(indices.shape, "int32")) - z = relay.gather(d, axis, i) + z = relay.gather(d, axis, i, support_negative_indices=True) func = relay.Function([d, i], z) @@ -1258,12 +1258,15 @@ def verify_gather(data, axis, indices, ref_res): verify_gather(data, axis, indices, ref_res) + # Verify negative indices also work properly, we should not change results + verify_gather(data, axis, indices, ref_res) + def test_gather_nd(target, dev, executor_kind): def verify_gather_nd(xshape, yshape, y_data, batch_dims=0): x = relay.var("x", relay.TensorType(xshape, "float32")) y = relay.var("y", relay.TensorType(yshape, "int32")) - z = relay.gather_nd(x, y, batch_dims) + z = relay.gather_nd(x, y, batch_dims, support_negative_indices=True) func = relay.Function([x, y], z) @@ -1309,6 +1312,11 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0): verify_gather_nd((3, 2, 2, 3, 4), (2, 3, 2, 2), None, 2) verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2, 3), None, 2) + # Verify negative indices work (copy of tests above with some indices replaced with negatives) + verify_gather_nd((2, 2, 2), (1, 2), [[-1, -2]], 1) + verify_gather_nd((2, 2, 2), (1, 2, 1), [[[-1], [-2]]], 1) + verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, -1], [2, 0, -2]], [[0, 0, 0], [1, 1, 1]]]) + def _verify_infiniteness_ops(relay_op, ref_op): for dtype in ["float32", "float16", "float16", "int32", "int16"]: From 1da783d17e3c4acfc868c71163168e3e4f9e90ae Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 9 Sep 2021 16:44:16 -0700 Subject: [PATCH 04/18] remove unneeded gather --- python/tvm/relay/frontend/onnx.py | 7 ++++--- tests/python/frontend/onnx/test_forward.py | 4 ---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 2fb310f68215..4fe269c19aa4 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1544,8 +1544,7 @@ def _impl_v1(cls, inputs, attr, params): data = inputs[0] indices = inputs[1] axis = attr.get("axis", 0) - indices = normalize_gather_indices(data, indices, axis) - return _op.gather(data, axis, indices) + return _op.gather(data, axis, indices, support_negative_indices=True) class GatherND(OnnxOpConverter): @@ -3560,7 +3559,9 @@ def _impl_v13(cls, inputs, attr, params): expanded_target_tensor = relay.expand_dims(target_tensor, 0) expanded_target_tensor = relay.nn.batch_flatten(expanded_target_tensor) - flattened_weights = relay.gather_nd(weight_tensor, expanded_target_tensor) + flattened_weights = relay.gather_nd( + weight_tensor, expanded_target_tensor, support_negative_indices=True + ) select_weights = relay.reshape_like(flattened_weights, loss) loss *= select_weights diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 81d478c54fca..5cbdc7f025c4 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4789,10 +4789,6 @@ def verify_eyelike(indata): "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded", "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded", # These nllloss tests are flaky and sometimes gives NaNs - # Investigate it here: https://github.com/apache/tvm/issues/8918 - "test_nllloss_NCd1d2d3_none_no_weight_negative_ii", - # Investigate it here: https://github.com/apache/tvm/issues/8964 - "test_nllloss_NCd1d2d3_sum_weight_high_ii", "test_qlinearmatmul_2D", "test_qlinearmatmul_3D", "test_range_float_type_positive_delta_expanded", From bb8594b2188238c7ccb4917afc10a9727d0fe347 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 9 Sep 2021 16:52:08 -0700 Subject: [PATCH 05/18] missing gather nd change --- python/tvm/relay/frontend/onnx.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 4fe269c19aa4..7c29ebdfbaf1 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1556,7 +1556,13 @@ def _impl_common(cls, data, indices, batch_dims=0): indices_shape = infer_shape(indices) indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1))) index_rank = indices_shape[-1] - return _op.gather_nd(data, indices, batch_dims, index_rank) + return _op.gather_nd( + data, + indices, + batch_dims=batch_dims, + support_negative_indices=True, + index_rank=index_rank, + ) @classmethod def _impl_v1(cls, inputs, attr, params): From a611b6417c778cc29b86376f3f296773034c50e3 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Thu, 9 Sep 2021 23:21:04 -0700 Subject: [PATCH 06/18] update tests --- tests/python/frontend/onnx/test_forward.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 5cbdc7f025c4..f38f553ba7c1 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4788,7 +4788,6 @@ def verify_eyelike(indata): "test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded", "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded", "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded", - # These nllloss tests are flaky and sometimes gives NaNs "test_qlinearmatmul_2D", "test_qlinearmatmul_3D", "test_range_float_type_positive_delta_expanded", From e81dfc93b51bffaf54a2015f33ff9b14a9a1f936 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Fri, 10 Sep 2021 01:01:13 -0700 Subject: [PATCH 07/18] proper tensor comparison --- src/te/tensor.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 6db3be040383..c1ba9823c55c 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -54,11 +54,13 @@ PrimExpr Tensor::operator()(Array indices, bool support_negative_indic } if (support_negative_indices) { - for (int i = 0; i < shape.size(); i++) { - PrimExpr new_index = indexmod(indices[i], shape[i]); + for (size_t i = 0; i < shape.size(); i++) { + PrimExpr new_index = if_then_else(indices[i] < make_const(indices[i]->dtype, 0), + indices[i] + shape[i], indices[i]); indices.Set(i, new_index); } } + return ProducerLoad((*this), indices); } From 260ba96bb8cf4b0dcad9eb1601e77d4e64ecd535 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Fri, 10 Sep 2021 01:11:18 -0700 Subject: [PATCH 08/18] blacking --- tests/python/frontend/onnx/test_forward.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f38f553ba7c1..c7c64031e2fd 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -235,8 +235,7 @@ def verify_with_ort( def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev): - from onnxruntime.quantization import (CalibrationDataReader, QuantType, - quantize_static) + from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize_static input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes] From 10c2913a447aff7fd6927d97d5ce0d35ca3dc44b Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Fri, 10 Sep 2021 11:19:38 -0700 Subject: [PATCH 09/18] lint --- include/tvm/te/tensor.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 4054c51bf99d..32e87954fffd 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -129,12 +129,14 @@ class Tensor : public DataProducer { /*! * \brief Take elements from the tensor * \param indices the indices. + * \param support_negative_indices whether we support negative indexing which is slightly slower. * \return the result expression representing tensor read. */ TVM_DLL PrimExpr operator()(Array indices, bool support_negative_indices = false) const; /*! * \brief Take elements from the tensor * \param indices the indices. + * \param support_negative_indices whether we support negative indexing which is slightly slower. * \return the result expression representing tensor read. */ TVM_DLL PrimExpr operator()(Array indices, bool support_negative_indices = false) const; From fef2cd67f4026c84856de65e55cc4cf4e29e43fe Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Fri, 10 Sep 2021 14:28:12 -0700 Subject: [PATCH 10/18] fix error --- python/tvm/relay/frontend/onnx.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7c29ebdfbaf1..f9756eaf02af 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3577,7 +3577,9 @@ def _impl_v13(cls, inputs, attr, params): target_tensor, relay.const(ignore_index, dtype=target_tensor.type_annotation.dtype) ) mask_tensor = relay.const(1, dtype="int8") - relay.cast(mask_tensor, "int8") - loss *= relay.cast_like(mask_tensor, loss) + loss = relay.where( + mask_tensor, loss, relay.const(0, infer_type(loss).checked_type.dtype) + ) # This is not explained super clearly in the onnx spec, but masked values don't # contribute toward the final value in reduction From 7949a3946471489d1701bb54ec17e73f30dff43f Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Fri, 10 Sep 2021 14:32:47 -0700 Subject: [PATCH 11/18] turn on test --- tests/python/relay/test_op_level3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 4b4f10e0c569..566683bf1202 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1259,7 +1259,7 @@ def verify_gather(data, axis, indices, ref_res, check_negative=False): verify_gather(data, axis, indices, ref_res) # Verify negative indices also work properly, we should not change results - verify_gather(data, axis, indices, ref_res) + verify_gather(data, axis, indices, ref_res, check_negative=True) def test_gather_nd(target, dev, executor_kind): From 02f1870d7f2e274cfd7e04678691da019ff201f0 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 13 Sep 2021 10:39:11 -0700 Subject: [PATCH 12/18] missing test case --- tests/python/relay/test_any.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 497177d241f0..b8c263eaab53 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -24,8 +24,8 @@ from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type -from utils.assert_diagnostic import DiagnosticTesting from utils import ref_funcs +from utils.assert_diagnostic import DiagnosticTesting def int32(val): @@ -2046,7 +2046,7 @@ def test_gather_nd(): def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np, batch_dims=0): x = relay.var("x", relay.TensorType(data_shape, "float32")) y = relay.var("y", relay.TensorType(indices_shape, "int32")) - z = relay.gather_nd(x, y, batch_dims, indices_shape[0]) + z = relay.gather_nd(x, y, batch_dims=batch_dims, index_rank=indices_shape[0]) mod = tvm.IRModule() mod["main"] = relay.Function([x, y], z) From a184f7c288b8a30ac19c6e4dcdb83b61490d6b88 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Wed, 15 Sep 2021 11:33:31 -0700 Subject: [PATCH 13/18] revert changes --- include/tvm/relay/attrs/transform.h | 8 -------- include/tvm/te/tensor.h | 6 ++---- include/tvm/topi/transform.h | 19 ++++++------------- python/tvm/relay/frontend/onnx.py | 9 +++------ python/tvm/relay/op/transform.py | 12 ++++-------- src/relay/op/tensor/transform.cc | 10 +++------- src/te/tensor.cc | 14 +++----------- tests/python/relay/test_op_level3.py | 17 +++-------------- 8 files changed, 24 insertions(+), 71 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 4d6386b59e9d..a8317e1e51ad 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -136,22 +136,17 @@ struct ScatterNDAttrs : public tvm::AttrsNode { struct GatherAttrs : public tvm::AttrsNode { Integer axis; - Bool support_negative_indices = Bool(false); TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherAttrs") { TVM_ATTR_FIELD(axis) .set_default(NullValue()) .describe("The axis over which to select values."); - TVM_ATTR_FIELD(support_negative_indices) - .set_default(Bool(false)) - .describe("If negative indices are supported."); } }; struct GatherNDAttrs : public tvm::AttrsNode { Integer batch_dims; Optional index_rank; - Bool support_negative_indices = Bool(false); TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") { TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions."); @@ -160,9 +155,6 @@ struct GatherNDAttrs : public tvm::AttrsNode { .describe( "The size of an indexing tuple, which is a fixed value. Only needed when the number of " "indexting tuples is dynamic."); - TVM_ATTR_FIELD(support_negative_indices) - .set_default(Bool(false)) - .describe("If negative indices are supported."); } }; diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 32e87954fffd..85677a726574 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -129,17 +129,15 @@ class Tensor : public DataProducer { /*! * \brief Take elements from the tensor * \param indices the indices. - * \param support_negative_indices whether we support negative indexing which is slightly slower. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr operator()(Array indices, bool support_negative_indices = false) const; + TVM_DLL PrimExpr operator()(Array indices) const; /*! * \brief Take elements from the tensor * \param indices the indices. - * \param support_negative_indices whether we support negative indexing which is slightly slower. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr operator()(Array indices, bool support_negative_indices = false) const; + TVM_DLL PrimExpr operator()(Array indices) const; /*! * \brief data structure to represent a slice that fixes first k coordinates. * This is used to enable syntax sugar of Tensor[x][y][z] to get the element. diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index ab6f7e83b4b9..f9bffb6e462a 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1219,13 +1219,11 @@ inline Tensor dyn_tile(const Tensor& x, Array new_shape, size_t rdim, * \param indices The indices of values to gather. * \param name The name of the operation. * \param tag The tag to mark the operation. - * \param support_negative_indices If negative indices are supported * * \return A Tensor whose op member is the gather operation */ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, - bool support_negative_indices = false, std::string name = "T_gather", - std::string tag = kInjective) { + std::string name = "T_gather", std::string tag = kInjective) { size_t ndim_d = data->shape.size(); size_t ndim_i = indices->shape.size(); ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar."; @@ -1244,8 +1242,6 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, out_shape.push_back(indices->shape[i]); } - PrimExpr axis_size = data->shape[axis]; - return compute( out_shape, [&](const Array& out_index) { @@ -1256,13 +1252,12 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, Array real_indices; for (size_t i = 0; i < ndim_i; ++i) { if (i == static_cast(axis)) { - PrimExpr index = indices(indices_position); - real_indices.push_back(index); + real_indices.push_back(indices(indices_position)); } else { real_indices.push_back(indices_position[i]); } } - return data(real_indices, support_negative_indices); + return data(real_indices); }, name, tag); } @@ -1275,13 +1270,11 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, * \param batch_dims The number of batch dimensions. * \param name The name of the operation. * \param tag The tag to mark the operation. - * \param support_negative_indices If negative indices are supported * * \return A Tensor whose op member is the gather_nd operation */ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0, - bool support_negative_indices = false, std::string name = "T_gather_nd", - std::string tag = kInjective) { + std::string name = "T_gather_nd", std::string tag = kInjective) { size_t ndim_d = data->shape.size(); size_t ndim_i = indices->shape.size(); ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions"; @@ -1317,12 +1310,12 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim real_indices.push_back(index); } if (real_indices.size() == ndim_d) { - return data(real_indices, support_negative_indices); + return data(real_indices); } for (size_t i = ndim_i - 1; i < out_index.size(); ++i) { real_indices.push_back(out_index[i]); } - return data(real_indices, support_negative_indices); + return data(real_indices); }, name, tag); } diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f9756eaf02af..55da4063123d 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1544,7 +1544,8 @@ def _impl_v1(cls, inputs, attr, params): data = inputs[0] indices = inputs[1] axis = attr.get("axis", 0) - return _op.gather(data, axis, indices, support_negative_indices=True) + indices = normalize_gather_indices(data, indices, axis) + return _op.gather(data, axis, indices) class GatherND(OnnxOpConverter): @@ -1560,7 +1561,6 @@ def _impl_common(cls, data, indices, batch_dims=0): data, indices, batch_dims=batch_dims, - support_negative_indices=True, index_rank=index_rank, ) @@ -3559,15 +3559,12 @@ def _impl_v13(cls, inputs, attr, params): input_tensor, axis=1, indices=relay.expand_dims(target_tensor, 1), - support_negative_indices=True, ) loss = relay.squeeze(loss, axis=[1]) expanded_target_tensor = relay.expand_dims(target_tensor, 0) expanded_target_tensor = relay.nn.batch_flatten(expanded_target_tensor) - flattened_weights = relay.gather_nd( - weight_tensor, expanded_target_tensor, support_negative_indices=True - ) + flattened_weights = relay.gather_nd(weight_tensor, expanded_target_tensor) select_weights = relay.reshape_like(flattened_weights, loss) loss *= select_weights diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index cd1c8e613c11..ca41203625e6 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1046,7 +1046,7 @@ def reverse_reshape(data, newshape): return _make.contrib_reverse_reshape(data, list(newshape)) -def gather(data, axis, indices, support_negative_indices=False): +def gather(data, axis, indices): """Gather values along given axis from given indices. E.g. for a 3D tensor, output is computed as: @@ -1071,10 +1071,6 @@ def gather(data, axis, indices, support_negative_indices=False): indices: relay.Expr The indices of values to gather. - support_negative_indices: bool - If True, support indices being negative. This is slower than supporting only - positive indices. - Examples -------- .. code-block:: python @@ -1084,10 +1080,10 @@ def gather(data, axis, indices, support_negative_indices=False): indices = [[0, 0], [1, 0]] relay.gather(data, axis, indices) = [[1, 1], [4, 3]] """ - return _make.gather(data, axis, indices, support_negative_indices) + return _make.gather(data, axis, indices) -def gather_nd(data, indices, batch_dims=0, support_negative_indices=False, index_rank=None): +def gather_nd(data, indices, batch_dims=0, index_rank=None): """Gather elements or slices from data and store to a tensor whose shape is defined by indices. @@ -1131,7 +1127,7 @@ def gather_nd(data, indices, batch_dims=0, support_negative_indices=False, index indices = [[1, 0]] relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]] """ - return _make.gather_nd(data, indices, batch_dims, support_negative_indices, index_rank) + return _make.gather_nd(data, indices, batch_dims, index_rank) def sequence_mask(data, valid_length, mask_value=0, axis=0): diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index e92843816cee..3781107eeee1 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3274,13 +3274,12 @@ bool GatherRel(const Array& types, int num_inputs, const Attrs& attrs, Array GatherCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); - return {topi::gather(inputs[0], param->axis, inputs[1], param->support_negative_indices)}; + return {topi::gather(inputs[0], param->axis, inputs[1])}; } -Expr MakeGather(Expr data, Integer axis, Expr indices, Bool support_negative_indices) { +Expr MakeGather(Expr data, Integer axis, Expr indices) { auto attrs = make_object(); attrs->axis = std::move(axis); - attrs->support_negative_indices = std::move(support_negative_indices); static const Op& op = Op::Get("gather"); return Call(op, {data, indices}, Attrs(attrs), {}); } @@ -3354,18 +3353,15 @@ Array GatherNDCompute(const Attrs& attrs, const Array& i const Type& out_type) { const auto* param = attrs.as(); ICHECK(param); - return { - topi::gather_nd(inputs[0], inputs[1], param->batch_dims, param->support_negative_indices)}; + return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)}; } Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, - Bool support_negative_indices = Bool(0), Optional index_rank = NullValue()) { static const Op& op = Op::Get("gather_nd"); auto attrs = make_object(); attrs->batch_dims = batch_dims; attrs->index_rank = index_rank; - attrs->support_negative_indices = support_negative_indices; return Call(op, {data, indices}, Attrs(attrs)); } diff --git a/src/te/tensor.cc b/src/te/tensor.cc index c1ba9823c55c..f0f2bb4c9f1f 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -39,12 +39,12 @@ IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name) Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } // Tensor -PrimExpr Tensor::operator()(Array indices, bool support_negative_indices) const { +PrimExpr Tensor::operator()(Array indices) const { Array arr(indices.begin(), indices.end()); - return operator()(arr, support_negative_indices); + return operator()(arr); } -PrimExpr Tensor::operator()(Array indices, bool support_negative_indices) const { +PrimExpr Tensor::operator()(Array indices) const { Array shape = (*this)->shape; if (shape.size() != 0) { @@ -53,14 +53,6 @@ PrimExpr Tensor::operator()(Array indices, bool support_negative_indic << "ndim = " << ndim() << ", indices.size=" << indices.size(); } - if (support_negative_indices) { - for (size_t i = 0; i < shape.size(); i++) { - PrimExpr new_index = if_then_else(indices[i] < make_const(indices[i]->dtype, 0), - indices[i] + shape[i], indices[i]); - indices.Set(i, new_index); - } - } - return ProducerLoad((*this), indices); } diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 566683bf1202..71902fff9545 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1238,16 +1238,13 @@ def test_scatter_add(self, target, dev, ref_data, dshape, ishape, axis, dtype): ], ) def test_gather(target, dev, executor_kind, data, axis, indices, ref_res): - def verify_gather(data, axis, indices, ref_res, check_negative=False): + def verify_gather(data, axis, indices, ref_res): data = np.asarray(data, dtype="float32") indices = np.asarray(indices, dtype="int32") ref_res = np.asarray(ref_res) - if check_negative: - axis_size = data.shape[axis] - indices = indices - axis_size d = relay.var("x", relay.TensorType(data.shape, "float32")) i = relay.var("y", relay.TensorType(indices.shape, "int32")) - z = relay.gather(d, axis, i, support_negative_indices=True) + z = relay.gather(d, axis, i) func = relay.Function([d, i], z) @@ -1258,15 +1255,12 @@ def verify_gather(data, axis, indices, ref_res, check_negative=False): verify_gather(data, axis, indices, ref_res) - # Verify negative indices also work properly, we should not change results - verify_gather(data, axis, indices, ref_res, check_negative=True) - def test_gather_nd(target, dev, executor_kind): def verify_gather_nd(xshape, yshape, y_data, batch_dims=0): x = relay.var("x", relay.TensorType(xshape, "float32")) y = relay.var("y", relay.TensorType(yshape, "int32")) - z = relay.gather_nd(x, y, batch_dims, support_negative_indices=True) + z = relay.gather_nd(x, y, batch_dims) func = relay.Function([x, y], z) @@ -1312,11 +1306,6 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0): verify_gather_nd((3, 2, 2, 3, 4), (2, 3, 2, 2), None, 2) verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2, 3), None, 2) - # Verify negative indices work (copy of tests above with some indices replaced with negatives) - verify_gather_nd((2, 2, 2), (1, 2), [[-1, -2]], 1) - verify_gather_nd((2, 2, 2), (1, 2, 1), [[[-1], [-2]]], 1) - verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, -1], [2, 0, -2]], [[0, 0, 0], [1, 1, 1]]]) - def _verify_infiniteness_ops(relay_op, ref_op): for dtype in ["float32", "float16", "float16", "int32", "int16"]: From 56650da45a767461430200c679a6f89b92a3c5ba Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Wed, 15 Sep 2021 12:01:20 -0700 Subject: [PATCH 14/18] add normalize_gather_indices --- python/tvm/relay/frontend/onnx.py | 15 +++++++++++---- python/tvm/relay/op/transform.py | 4 ---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 55da4063123d..5ca2736b3aec 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3546,6 +3546,11 @@ def _impl_v13(cls, inputs, attr, params): ) input_tensor, target_tensor = inputs[0], inputs[1] + + # Convert negative indices --> positive indices for gather ops, note we have to + # use the original target tensor to interact with ignore_index to have proper behavior. + normalized_target_tensor = normalize_gather_indices(input_tensor, target_tensor, 1) + if len(inputs) == 3: weight_tensor = inputs[2] else: @@ -3558,13 +3563,15 @@ def _impl_v13(cls, inputs, attr, params): loss = -relay.gather( input_tensor, axis=1, - indices=relay.expand_dims(target_tensor, 1), + indices=relay.expand_dims(normalized_target_tensor, 1), ) loss = relay.squeeze(loss, axis=[1]) - expanded_target_tensor = relay.expand_dims(target_tensor, 0) - expanded_target_tensor = relay.nn.batch_flatten(expanded_target_tensor) - flattened_weights = relay.gather_nd(weight_tensor, expanded_target_tensor) + expanded_normalized_target_tensor = relay.expand_dims(normalized_target_tensor, 0) + expanded_normalized_target_tensor = relay.nn.batch_flatten( + expanded_normalized_target_tensor + ) + flattened_weights = relay.gather_nd(weight_tensor, expanded_normalized_target_tensor) select_weights = relay.reshape_like(flattened_weights, loss) loss *= select_weights diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index ca41203625e6..2c299022bd6e 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1102,10 +1102,6 @@ def gather_nd(data, indices, batch_dims=0, index_rank=None): The size of an indexing tuple, which is a fixed value and the same as indices.shape[0] Only needed when other dimensions of indices are dynamic. - support_negative_indices: bool - If True, support indices being negative. This is slower than supporting only - positive indices. - Returns ------- ret : relay.Expr From 73d3d55ef04768df5b17772259b4e5a48e3a7d54 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Wed, 15 Sep 2021 12:05:08 -0700 Subject: [PATCH 15/18] undo change --- include/tvm/topi/transform.h | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index f9bffb6e462a..659bd7e876b8 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1301,13 +1301,11 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim real_indices.push_back(out_index[i]); } for (size_t i = 0; i < indices_dim0; ++i) { - indices_position.Set(0, make_const(DataType::Int(32), i)); - PrimExpr index = indices(indices_position); - - if (!indices->dtype.is_int()) { - index = tvm::cast(tvm::DataType::Int(32), index); + if (indices->dtype.is_int()) { + real_indices.push_back(indices(indices_position)); + } else { + real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position))); } - real_indices.push_back(index); } if (real_indices.size() == ndim_d) { return data(real_indices); From d7e24f8ef84560df9d4466d373789754f3bed737 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Wed, 15 Sep 2021 12:05:41 -0700 Subject: [PATCH 16/18] update --- include/tvm/topi/transform.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 659bd7e876b8..8d1a49a4cc5f 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1301,6 +1301,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim real_indices.push_back(out_index[i]); } for (size_t i = 0; i < indices_dim0; ++i) { + indices_position.Set(0, make_const(DataType::Int(32), i)); if (indices->dtype.is_int()) { real_indices.push_back(indices(indices_position)); } else { From 5ff388ba49e8038329df4e1895db49a81f1b8da7 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Wed, 15 Sep 2021 12:07:46 -0700 Subject: [PATCH 17/18] more removing diffs --- src/te/tensor.cc | 10 +++------- tests/python/relay/test_op_level3.py | 3 +++ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/te/tensor.cc b/src/te/tensor.cc index f0f2bb4c9f1f..5dd6736740e6 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -45,14 +45,10 @@ PrimExpr Tensor::operator()(Array indices) const { } PrimExpr Tensor::operator()(Array indices) const { - Array shape = (*this)->shape; - - if (shape.size() != 0) { - ICHECK_EQ(shape.size(), indices.size()) - << "Tensor dimension mismatch in read " - << "ndim = " << ndim() << ", indices.size=" << indices.size(); + if (ndim() != 0) { + ICHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read " + << "ndim = " << ndim() << ", indices.size=" << indices.size(); } - return ProducerLoad((*this), indices); } diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 71902fff9545..eaddd33678df 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -21,8 +21,10 @@ import numpy as np import pytest + import tvm import tvm.testing + from tvm import relay, te from tvm.error import TVMError from tvm.relay import create_executor, transform @@ -30,6 +32,7 @@ from utils import ref_funcs + executor_kind = tvm.testing.parameter("graph", "debug") From dbbd42e0234c64baa1e88eddfa33d0022ee4dfdf Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Wed, 15 Sep 2021 12:08:20 -0700 Subject: [PATCH 18/18] more undoing --- src/te/tensor.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 5dd6736740e6..b48f39a38627 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -49,6 +49,7 @@ PrimExpr Tensor::operator()(Array indices) const { ICHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read " << "ndim = " << ndim() << ", indices.size=" << indices.size(); } + return ProducerLoad((*this), indices); }