From dff202e975d1bf9089bd1181c6a06a90ee4926c5 Mon Sep 17 00:00:00 2001
From: zxy844288792
Date: Wed, 14 Apr 2021 18:01:18 +0000
Subject: [PATCH 1/5] add batch_dims support

---
 include/tvm/relay/attrs/transform.h          |   4 +
 include/tvm/topi/transform.h                 | 119 +++++++++++++-----
 python/tvm/relay/frontend/tensorflow.py      |  13 +-
 python/tvm/relay/op/_transform.py            |  15 ++-
 python/tvm/relay/op/transform.py             |   7 +-
 python/tvm/topi/transform.py                 |   9 +-
 src/relay/op/make_op.h                       |   2 +-
 src/relay/op/tensor/transform.cc             |  23 +++-
 src/topi/transform.cc                        |  14 ++-
 .../frontend/tensorflow/test_forward.py      |  38 ++++--
 10 files changed, 172 insertions(+), 72 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 113c8209fe6a..723f9ecdab90 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -145,10 +145,14 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
 };
 
 struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
+  Integer batch_dims;
   Integer axis;
   std::string mode;
 
   TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
+    TVM_ATTR_FIELD(batch_dims)
+        .set_default(0)
+        .describe("The number of leading batch dimensions.");
     TVM_ATTR_FIELD(axis)
         .set_default(NullValue<Integer>())
         .describe("The axis over which to select values.");
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 114b8f617387..7e83387016db 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -770,8 +770,9 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
  *
  * \return A Tensor whose op member is the take operation
  */
-inline Tensor take(const Tensor& a, const Tensor& indices, std::string mode = "clip",
-                   std::string name = "T_take", std::string tag = kInjective) {
+inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
+                   std::string mode = "clip", std::string name = "T_take",
+                   std::string tag = kInjective) {
   Array<PrimExpr> a_shape = a->shape;
   Array<PrimExpr> out_shape = indices->shape;
   PrimExpr a_size = 1;
@@ -846,6 +847,7 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
  *
  * \param a The source array.
  * \param indices The indices of the values to extract.
+ * \param batch_dims The number of batch dimensions. Defaults to 0.
  * \param axis The axis over which to select values. By default,
  * the flattened input array is used.
 * \param mode The mode for handling out of bound indices.
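For reference, the semantics documented above match NumPy's take when batch_dims == 0; a minimal NumPy sketch (illustration only, not part of the patch):

    import numpy as np

    a = np.arange(12).reshape(3, 4)
    idx = np.array([[1, 0], [0, 1]])

    # axis given: result shape is a.shape[:1] + idx.shape + a.shape[2:] = (3, 2, 2)
    np.take(a, idx, axis=1, mode="clip")

    # axis omitted: the flattened input array is used; result shape is idx.shape
    np.take(a.ravel(), idx, mode="clip")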
@@ -854,46 +856,99 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
  *
  * \return A Tensor whose op member is the take operation
  */
-inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string mode = "clip",
-                   std::string name = "T_take", std::string tag = kInjective) {
+inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis,
+                   std::string mode = "clip", std::string name = "T_take",
+                   std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(a->shape.size());
   }
   ICHECK_GE(axis, 0) << "axis out of bounds";
   ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
   auto axis_dim = a->shape[axis];
   int indices_len = static_cast<int>(indices->shape.size());
-  Array<PrimExpr> out_shape;
-  for (size_t i = 0; i < a->shape.size(); ++i) {
-    if (axis == static_cast<int>(i)) {
-      for (size_t j = 0; j < indices->shape.size(); ++j) {
-        out_shape.push_back(indices->shape[j]);
-      }
-    } else {
-      out_shape.push_back(a->shape[i]);
-    }
-  }
+
+  int batch_dims_ = batch_dims;
+  if (batch_dims_ != 0) {
+    ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << "batch_dims out of bounds";
+    ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of bounds";
+
+    if (batch_dims_ < 0) {
+      batch_dims_ = indices->shape.size() + batch_dims_;
+    }
+
+    ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
+    ICHECK_GE(axis, batch_dims_) << "batch_dims must be less than or equal to axis";
+    for (int i = 0; i < batch_dims_; ++i) {
+      auto addr1 = a->shape[i];
+      auto addr2 = indices->shape[i];
+      auto v1 = static_cast<IntImm*>(&addr1)->get()->value;
+      auto v2 = static_cast<IntImm*>(&addr2)->get()->value;
+      ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
+    }
+  }
+
+  // The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
+  // a.shape[axis + 1:].
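+  // Worked example (added for clarity): a.shape = (2, 3, 4), indices.shape = (2, 5),
+  // batch_dims = 1, axis = 1 gives (2,) + (5,) + (4,) = (2, 5, 4).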
+
+  Array<PrimExpr> out_shape;
+  for (int i = 0; i < batch_dims_; ++i) {
+    out_shape.push_back(a->shape[i]);
+  }
+  for (int i = batch_dims_; i < axis; ++i) {
+    out_shape.push_back(a->shape[i]);
+  }
+  for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); ++i) {
+    out_shape.push_back(indices->shape[i]);
+  }
+  for (size_t i = axis + 1; i < a->shape.size(); ++i) {
+    out_shape.push_back(a->shape[i]);
+  }
 
   if (mode == "clip") {
-    return compute(
-        out_shape,
-        [&](const Array<Var>& out_index) {
-          Array<PrimExpr> indices_position;
-          for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
-            indices_position.push_back(out_index[j]);
-          }
-          Array<PrimExpr> real_indices;
-          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
-            real_indices.push_back(out_index[j]);
-          }
-          auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
-          real_indices.push_back(idx);
-          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
-            real_indices.push_back(out_index[j]);
-          }
-          return a(real_indices);
-        },
-        name, tag);
+    if (batch_dims_ == 0) {
+      return compute(
+          out_shape,
+          [&](const Array<Var>& out_index) {
+            Array<PrimExpr> indices_position;
+            for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
+              indices_position.push_back(out_index[j]);
+            }
+            Array<PrimExpr> real_indices;
+            for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
+              real_indices.push_back(out_index[j]);
+            }
+            auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
+            real_indices.push_back(idx);
+            for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
+              real_indices.push_back(out_index[j]);
+            }
+            return a(real_indices);
+          },
+          name, tag);
+    } else {
+      return compute(
+          out_shape,
+          [&](const Array<Var>& out_index) {
+            Array<PrimExpr> indices_position;
+            for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
+              indices_position.push_back(out_index[j]);
+            }
+            for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
+              indices_position.push_back(out_index[j]);
+            }
+            Array<PrimExpr> real_indices;
+            for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
+              real_indices.push_back(out_index[j]);
+            }
+            auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
+            real_indices.push_back(idx);
+            for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
+              real_indices.push_back(out_index[j]);
+            }
+            return a(real_indices);
+          },
+          name, tag);
+    }
   } else if (mode == "fast") {
     LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
                     "Make sure input indices are in bounds";
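A NumPy reference for the batched clip-mode compute above may help review; batched_take is a name invented for this sketch (not part of the patch):

    import numpy as np

    def batched_take(a, indices, axis, batch_dims):
        # Result shape: a.shape[:axis] + indices.shape[batch_dims:] + a.shape[axis + 1:]
        ind_len = len(indices.shape)
        out_shape = a.shape[:axis] + indices.shape[batch_dims:] + a.shape[axis + 1 :]
        out = np.empty(out_shape, dtype=a.dtype)
        for out_idx in np.ndindex(*out_shape):
            # indices_position: shared batch prefix + the block that replaced `axis`.
            pos = out_idx[:batch_dims] + out_idx[axis : axis + ind_len - batch_dims]
            # mode == "clip": clamp the gathered index into [0, a.shape[axis] - 1].
            k = min(max(int(indices[pos]), 0), a.shape[axis] - 1)
            out[out_idx] = a[out_idx[:axis] + (k,) + out_idx[axis + ind_len - batch_dims :]]
        return out

    # batched_take(np.arange(24).reshape(2, 3, 4), np.array([[0, 2], [1, 1]]),
    #              axis=1, batch_dims=1).shape == (2, 2, 4)

With batch_dims == 0 this degenerates to the unbatched lambda above, since the batch prefix is empty.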
" "Make sure input indices are in bound"; diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index f566a3fb92d0..4bd332fa0159 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2002,14 +2002,19 @@ def _impl(inputs, attr, params, mod): axis = _get_num_param(params, inputs.pop(2)) else: axis = 0 + batch_dims = 0 if int(attr.get("batch_dims", 0)) != 0: - raise tvm.error.OpAttributeUnImplemented("Attribute batch_dims is not supported") + batch_dims = int(attr.get("batch_dims", 0)) new_input = inputs[0:2] - return AttrCvt( + op_ = AttrCvt( op_name="take", - extras={"axis": tvm.tir.const(axis, "int32")}, - ignores=["Tindices", "Tparams", "validate_indices", "Taxis", "_class", "batch_dims"], + extras={ + "axis": tvm.tir.const(axis, "int32"), + "batch_dims": tvm.tir.const(batch_dims, "int32"), + }, + ignores=["Tindices", "Tparams", "validate_indices", "Taxis", "_class"], )(new_input, attr) + return op_ return _impl diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 2920c9955b9b..76adee477a1a 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -390,7 +390,7 @@ def _take_no_axis_shape_func(indices_shape, out_ndim): @script -def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim): +def _take_with_axis_shape_func(data_shape, indices_shape, axis, batch_dims, out_ndim): out = output_tensor((out_ndim,), "int64") for i in const_range(axis): out[i] = data_shape[i] @@ -399,10 +399,10 @@ def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim): for i in const_range(axis + 1, len(data_shape)): out[i - 1] = data_shape[i] else: - for i in const_range(len(indices_shape)): - out[axis + i] = indices_shape[i] + for i in const_range(len(indices_shape) - batch_dims): + out[axis + i] = indices_shape[i + batch_dims] for i in const_range(axis + 1, len(data_shape)): - out[len(indices_shape) + i - 1] = data_shape[i] + out[len(indices_shape) + i - 1 - batch_dims] = data_shape[i] return out @@ -414,11 +414,16 @@ def take_shape_func(attrs, inputs, out_ndims): if attrs.axis is None: return [_take_no_axis_shape_func(inputs[1], out_ndims[0])] axis = get_const_int(attrs.axis) + batch_dims = get_const_int(attrs.batch_dims) data_ndim = int(inputs[0].shape[0]) + if inputs[1].shape: + indicies_ndim = int(inputs[1].shape[0]) if axis < 0: axis += data_ndim assert 0 <= axis < data_ndim - return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])] + if batch_dims < 0: + batch_dims += indicies_ndim + return [_take_with_axis_shape_func(*inputs, convert(axis), convert(batch_dims), out_ndims[0])] @_reg.register_legalize("take") diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index df2686196151..4e406e81ef68 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -388,7 +388,7 @@ def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_e return _make.reshape_like(data, shape_like, lhs_begin, lhs_end, rhs_begin, rhs_end) -def take(data, indices, axis=None, mode="clip"): +def take(data, indices, axis=None, batch_dims=0, mode="clip"): """Take elements from an array along an axis. Parameters @@ -403,6 +403,9 @@ def take(data, indices, axis=None, mode="clip"): The axis over which to select values. By default, the flattened input array is used. + batch_dims : int + The number of batch dimensions. By default is 0. 
+
     mode : str, optional
         Specifies how out-of-bound indices will behave [clip, wrap, fast].
         clip: clip to the range (default).
@@ -414,7 +417,7 @@ def take(data, indices, axis=None, mode="clip"):
     ret : relay.Expr
         The computed result.
     """
-    return _make.take(data, indices, axis, mode)
+    return _make.take(data, indices, batch_dims, axis, mode)
 
 
 def full(fill_value, shape=(), dtype=""):
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index 6ddbc73e4666..df30ff775f60 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -396,7 +396,7 @@ def split(ary, indices_or_sections, axis=0):
     return cpp.split(ary, indices_or_sections, axis)
 
 
-def take(a, indices, axis=None, mode="clip"):
+def take(a, indices, axis=None, batch_dims=0, mode="clip"):
     """Take elements from an array along an axis.
 
     Parameters
@@ -411,6 +411,9 @@ def take(a, indices, axis=None, mode="clip"):
         The axis over which to select values. By default,
         the flattened input array is used.
 
+    batch_dims : int
+        The number of batch dimensions. Defaults to 0.
+
     mode : str, optional
         Specifies how out-of-bound indices will behave.
         clip - clip to the range (default)
@@ -422,8 +425,8 @@ def take(a, indices, axis=None, mode="clip"):
     ret : tvm.te.Tensor
     """
     if axis is None:
-        return cpp.take(a, indices, mode)
-    return cpp.take(a, indices, int(axis), mode)
+        return cpp.take(a, indices, int(batch_dims), mode)
+    return cpp.take(a, indices, int(batch_dims), int(axis), mode)
 
 
 @tvm.target.generic_func
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index cc1ff44952ef..bbfef5883e3d 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -107,7 +107,7 @@ Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype);
 
 Expr MakeShapeOf(Expr data, DataType dtype);
 
-Expr MakeTake(Expr data, Expr indices, Integer axis, String mode);
+Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode);
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index e937cb0c7b1f..4b12c718caea 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1200,15 +1200,24 @@ bool TakeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   const auto ndim_data = static_cast<int>(data->shape.size());
   const auto ndim_indices = static_cast<int>(indices->shape.size());
   int axis = static_cast<int>(param->axis->value);
+  int batch_dims = static_cast<int>(param->batch_dims->value);
   if (axis < 0) axis += ndim_data;
+  if (batch_dims < 0) batch_dims += ndim_indices;
   ICHECK_LE(axis, ndim_data) << "axis should be within data shape"
                              << ", but got = " << axis;
+  ICHECK_LE(batch_dims, ndim_indices) << "batch_dims should be within indices shape"
+                                      << ", but got = " << batch_dims;
+  ICHECK_LE(batch_dims, axis) << "batch_dims should be less than or equal to axis"
+                              << ", but got = " << batch_dims;
 
-  oshape.reserve(ndim_data - 1 + ndim_indices);
-  for (int i = 0; i < axis; ++i) {
+  oshape.reserve(ndim_data - 1 + ndim_indices - batch_dims);
+  for (int i = 0; i < batch_dims; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+  for (int i = batch_dims; i < axis; ++i) {
     oshape.emplace_back(data->shape[i]);
   }
-  for (int i = 0; i < ndim_indices; ++i) {
+  for (int i = batch_dims; i < ndim_indices; ++i) {
     oshape.emplace_back(indices->shape[i]);
   }
   for (int i = axis + 1; i < ndim_data; ++i) {
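The same output-shape rule is implemented three times (TakeRel above, the topi kernel, and _take_with_axis_shape_func); condensed as plain Python for review (helper name is illustrative only):

    def take_out_shape(data_shape, indices_shape, axis, batch_dims):
        # Negative axis/batch_dims are normalized exactly as TakeRel does.
        if axis < 0:
            axis += len(data_shape)
        if batch_dims < 0:
            batch_dims += len(indices_shape)
        assert 0 <= batch_dims <= axis < len(data_shape)
        return data_shape[:axis] + indices_shape[batch_dims:] + data_shape[axis + 1 :]

    # take_out_shape((2, 3, 4), (2, 5), axis=1, batch_dims=1) == (2, 5, 4)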
@@ -1224,14 +1233,16 @@ Array<te::Tensor> TakeCompute(const Attrs& attrs, const Array<te::Tensor>& input
   const auto* param = attrs.as<TakeAttrs>();
   ICHECK(param != nullptr);
   if (!param->axis.defined()) {
-    return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->mode)};
+    return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->batch_dims, param->mode)};
   } else {
-    return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->axis, param->mode)};
+    return Array<te::Tensor>{
+        topi::take(inputs[0], inputs[1], param->batch_dims, param->axis, param->mode)};
   }
 }
 
-Expr MakeTake(Expr data, Expr indices, Integer axis, String mode) {
+Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode) {
   auto attrs = make_object<TakeAttrs>();
+  attrs->batch_dims = std::move(batch_dims);
   attrs->axis = std::move(axis);
   attrs->mode = std::move(mode);
   static const Op& op = Op::Get("take");
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index f71fae3c5aaa..50f2208721c1 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -87,13 +87,15 @@ TVM_REGISTER_GLOBAL("topi.layout_transform").set_body([](TVMArgs args, TVMRetVal
 });
 
 TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) {
-  if (args.size() == 3) {
-    std::string mode = args[2];
-    *rv = take(args[0], args[1], mode);
-  } else {
-    int axis = args[2];
+  if (args.size() == 4) {
     std::string mode = args[3];
-    *rv = take(args[0], args[1], axis, mode);
+    int batch_dims = args[2];
+    *rv = take(args[0], args[1], batch_dims, mode);
+  } else {
+    int batch_dims = args[2];
+    int axis = args[3];
+    std::string mode = args[4];
+    *rv = take(args[0], args[1], batch_dims, axis, mode);
   }
 });
 
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index f4e7522b7038..e7b189345c61 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -2705,14 +2705,14 @@ def test_forward_truncatemod():
 # --------------------------
 
 
-def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
+def _test_gather(ip_shape, indice_shape, indice_value, axis, batch_dims, dtype):
     """ One iteration of a GatherV2 """
     tf.reset_default_graph()
     with tf.Graph().as_default():
         in_data = tf.placeholder(dtype, ip_shape, name="in_data")
         indices = tf.placeholder("int32", indice_shape, name="indices")
-        out = tf.gather(in_data, indices, axis=axis)
+        out = tf.gather(in_data, indices, axis=axis, batch_dims=batch_dims)
         np_data = np.random.uniform(1, 10, size=ip_shape).astype(dtype)
 
         def _fill_indices(indice_value):
@@ -2724,22 +2724,34 @@ def _fill_indices(indice_value):
             return indices
 
         np_indices = _fill_indices(indice_value)
-
        compare_tf_with_tvm([np_data, np_indices], ["in_data:0", "indices:0"], out.name)
 
 
 def test_forward_gather():
     """test Gather/GatherV2 layer"""
-    _test_gather((4,), (1,), 1, 0, "int32")
-    _test_gather((4,), (1,), 1, 0, "float32")
-    _test_gather((1, 4), (1,), [0], 0, "int32")
-    _test_gather((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 0, "float32")
-    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, "int32")
-    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 1, "int32")
-    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, "float32")
-    _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 0, "int32")
-    _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, "int32")
-    _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, "float32")
+    _test_gather((4,), (1,), 1, 0, 1, "int32")
+    _test_gather((4,), (1,), 1, 0, 0, "float32")
+    _test_gather((1, 4), (1,), [0], 0, 0, "int32")
+    _test_gather((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 0, "float32")
+    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 0, "int32")
+    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 1, 0, "int32")
+    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 0, "float32")
+    _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 0, 0, "int32")
+    _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 0, "int32")
+    _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 0, "float32")
+    _test_gather((2, 2), (2, 2), [[0, 0], [0, 0]], 1, 1, "float32")
+    _test_gather(
+        (2, 2, 3, 6), (2, 2, 3), [[[1, 1, 0], [0, 0, 1]], [[0, 1, 0], [1, 0, 1]]], 2, 2, "float32"
+    )
+    _test_gather(
+        (2, 2, 3, 6), (2, 2, 3), [[[1, 1, 0], [0, 0, 1]], [[0, 1, 0], [1, 0, 1]]], 3, 1, "float32"
+    )
+    _test_gather(
+        (2, 2, 3, 6), (2, 2, 3), [[[1, 1, 0], [0, 0, 1]], [[0, 1, 0], [1, 0, 1]]], 3, 2, "float32"
+    )
+    _test_gather(
+        (2, 2, 3, 6), (2, 2, 3), [[[1, 1, 0], [0, 0, 1]], [[0, 1, 0], [1, 0, 1]]], 3, 0, "float32"
+    )
 
 
 #######################################################################
From 1423c67498502768eee95d280eb5c4b511e3420d Mon Sep 17 00:00:00 2001
From: zxy844288792
Date: Fri, 30 Apr 2021 21:07:21 +0000
Subject: [PATCH 2/5] fix lint

---
 include/tvm/topi/transform.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 7e83387016db..cf1950ced628 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -763,6 +763,7 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
  *
  * \param a The source array.
  * \param indices The indices of the values to extract.
+ * \param batch_dims The number of batch dimensions.
 * \param mode The mode of the operation.
 * \param name The name of the operation.
 * \param mode The mode to handle out of bound indices.
From bdad09ef54ac741c623c414016757abc99807c5c Mon Sep 17 00:00:00 2001
From: zxy844288792
Date: Fri, 30 Apr 2021 23:31:59 +0000
Subject: [PATCH 3/5] add check for num of arguments for topi.take

---
 include/tvm/topi/transform.h | 2 +-
 src/topi/transform.cc        | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index cf1950ced628..379234a5c65a 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -878,7 +878,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
     }
 
     ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
-    ICHECK_GE(axis, batch_dims_) << "batch_dims must be less than or equal to axis";
+    ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
     for (int i = 0; i < batch_dims_; ++i) {
       auto addr1 = a->shape[i];
       auto addr2 = indices->shape[i];
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 50f2208721c1..0bce3bbc7f53 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -92,6 +92,7 @@ TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) {
     int batch_dims = args[2];
     *rv = take(args[0], args[1], batch_dims, mode);
   } else {
+    ICHECK_EQ(args.size(), 5) << "topi.take expects 4 or 5 arguments";
     int batch_dims = args[2];
     int axis = args[3];
     std::string mode = args[4];
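The 4- and 5-argument forms checked above correspond to the two take overloads; the Python wrapper dispatches as below (a usage sketch against the post-patch topi API):

    import tvm
    from tvm import te, topi

    a = te.placeholder((2, 3, 4), name="a", dtype="float32")
    indices = te.placeholder((2, 5), name="indices", dtype="int32")

    # axis given -> 5-argument packed call (data, indices, batch_dims, axis, mode).
    out = topi.take(a, indices, axis=1, batch_dims=1)   # shape (2, 5, 4)

    # axis omitted -> 4-argument packed call; the input is flattened first.
    flat = topi.take(a, indices)                        # shape (2, 5)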
From 292e2a5f6ce5abb7b49ee6984d76fb60e39fd19f Mon Sep 17 00:00:00 2001
From: Xingyu
Date: Tue, 4 May 2021 03:04:45 +0000
Subject: [PATCH 4/5] fix gpu test cases

---
 python/tvm/relay/frontend/mxnet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 5415c77097a2..11800eaf3cb3 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -911,7 +911,7 @@ def _mx_take(inputs, attrs):
     if mode == "raise":
         raise tvm.error.OpAttributeUnimplemented("take with raise mode is not supported yet")
     axis = attrs.get_int("axis", 0)
-    return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)
+    return _op.take(inputs[0], inputs[1].astype("int32"), axis=axis, mode=mode)
 
 
 def _mx_gather_nd(inputs, attrs):
From 369852babece1da72a6a7ce38bc607ec1028b4d1 Mon Sep 17 00:00:00 2001
From: Xingyu
Date: Tue, 4 May 2021 20:57:58 +0000
Subject: [PATCH 5/5] add check for batch_dims in take_grad

---
 python/tvm/relay/op/_tensor_grad.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 108bef0242fe..d5b891088933 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -711,6 +711,7 @@ def make_scalar_tensor(v):
     # TODO(@altanh): we currently assume indices are in range
     data, indices = orig.args
     axis = orig.attrs.axis
+    batch_dims = orig.attrs.batch_dims
     zero, one = map(make_scalar_tensor, [0, 1])
     data_grad = zeros_like(data)
     try:
@@ -726,6 +727,12 @@ def make_scalar_tensor(v):
             data_shape = (data_shape,)
     else:
         axis = int(axis)
+        if batch_dims is None:
+            batch_dims = 0
+        else:
+            batch_dims = int(batch_dims)
+        if batch_dims != 0:
+            raise OpError("take_grad only supports batch_dims equal to 0")
         strides = [1] * len(data_shape)
 
     if len(indices.checked_type.shape) == 0:
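A minimal end-to-end sketch of the new Relay surface (forward only; per PATCH 5/5, gradients still require batch_dims == 0):

    import numpy as np
    import tvm
    from tvm import relay

    data = relay.var("data", shape=(2, 3, 4), dtype="float32")
    indices = relay.var("indices", shape=(2, 2), dtype="int32")
    # batch_dims=1: row i of `indices` gathers only within batch i of `data`.
    out = relay.take(data, indices, axis=1, batch_dims=1)
    mod = tvm.IRModule.from_expr(relay.Function([data, indices], out))

    d = np.random.uniform(size=(2, 3, 4)).astype("float32")
    i = np.array([[0, 2], [1, 1]], dtype="int32")
    res = relay.create_executor("graph", mod=mod).evaluate()(d, i)
    assert res.shape == (2, 2, 4)  # data.shape[:1] + indices.shape[1:] + data.shape[2:]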