diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index cbe989f93558..cb3d89eb3e01 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -401,6 +401,14 @@ struct SparseToDenseAttrs : public tvm::AttrsNode { } }; // struct SparseToDenseAttrs +/*! \brief Attributes used in sparse_segment_sqrtn operator */ +struct SparseSegmentSqrtNAttrs : public tvm::AttrsNode { + int num_segments; + TVM_DECLARE_ATTRS(SparseSegmentSqrtNAttrs, "relay.attrs.SparseSegmentSqrtNAttrs") { + TVM_ATTR_FIELD(num_segments).describe("Number of Segments in the output tensor"); + } +}; // struct SparseSegmentSqrtNAttrs + /*! \brief Attributes for ndarray_size operator */ struct NdarraySizeAttrs : public tvm::AttrsNode { DataType dtype; diff --git a/include/tvm/support/logging.h b/include/tvm/support/logging.h index d98363ea1c1b..ced1902a1bd1 100644 --- a/include/tvm/support/logging.h +++ b/include/tvm/support/logging.h @@ -139,10 +139,10 @@ constexpr const char* kTVM_INTERNAL_ERROR_MESSAGE = #define ICHECK_GE(x, y) ICHECK_BINARY_OP(_GE, >=, x, y) #define ICHECK_EQ(x, y) ICHECK_BINARY_OP(_EQ, ==, x, y) #define ICHECK_NE(x, y) ICHECK_BINARY_OP(_NE, !=, x, y) -#define ICHECK_NOTNULL(x) \ - ((x) == nullptr ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ - << tvm::kTVM_INTERNAL_ERROR_MESSAGE << __INDENT << "Check not null: " #x \ - << ' ', \ +#define ICHECK_NOTNULL(x) \ + ((x) == nullptr ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ + << tvm::kTVM_INTERNAL_ERROR_MESSAGE << ICHECK_INDENT \ + << "Check not null: " #x << ' ', \ (x) : (x)) // NOLINT(*) /*! \brief The diagnostic level, controls the printing of the message. */ diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index a04762f28feb..70e5c2061da7 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -506,7 +506,7 @@ inline Array split(const Tensor& x, Array split_indices, int a begin_ids.push_back(idx); } - Array > out_shapes; + Array> out_shapes; for (size_t i = 0; i < begin_ids.size(); ++i) { PrimExpr out_axis_size; if (i == begin_ids.size() - 1) { @@ -1386,6 +1386,57 @@ inline Array meshgrid(const Array& inputs, const std::string& in return result; } +/*! + * \brief Compute the sparse segment sum on the indices over the segment_ids divided by the sqrt + * of the length of the segment + * + * \param data A Tensor with data that will be assembled in the output. + * \param selected_indices A 1-D Tensor with indices into data. Has same rank as segment_ids. + * \param segment_ids A 1-D Tensor with indices into the output Tensor. Values should be sorted and + * can be repeated. + * \param num_segments An optional int32 scalar. Indicates the size of the output Tensor. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the sparse_segment_sqrtn operation + */ +inline Array SparseSegmentSqrtN(const Tensor& data, const Tensor& selected_indices, + const Tensor& segment_ids, int num_segments, + const std::string name = "T_sparse_segment_sqrtn", + std::string tag = kInjective) { + Array result; + Array new_data_shape; + if (num_segments != -1) { + new_data_shape.push_back(num_segments); + } else { + new_data_shape.push_back(selected_indices->shape[0]); + } + for (int i = 1; i < static_cast(data->shape.size()); ++i) { + new_data_shape.push_back(data->shape[i]); + } + auto selected_data = tvm::topi::take(data, selected_indices, 0, "clip"); + + result.push_back(compute( + new_data_shape, + [&](const Array& indices) { + PrimExpr ret = static_cast(0.0); + PrimExpr length_segment = static_cast(0.0); + for (int i = 0; i < GetConstInt(segment_ids->shape[0]); ++i) { + Array secondary_indices; + secondary_indices.push_back(i); + secondary_indices.insert(secondary_indices.end(), indices.begin() + 1, indices.end()); + PrimExpr condition = indices[0] == segment_ids[i]; + length_segment += if_then_else(condition, 1, 0); + ret += if_then_else(condition, selected_data(secondary_indices), 0); + } + // length_segment = if_then_else(length_segment == 0, 1, length_segment); + PrimExpr sqrt_length_segment = + tvm::sqrt(if_then_else(length_segment == 0, 1, length_segment)); + return div(ret, sqrt_length_segment); + }, + name, tag)); + return result; +} // namespace topi /*! * \brief Transform the layout according to \p src_layout and \p dst_layout * \param src the source input. diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d5746a38582c..3cc7f9708862 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -975,6 +975,45 @@ def _impl(inputs, attr, params, mod): return _impl +def _sparse_segment_sqrtn_with_num_segments(): + def _impl(inputs, attr, params, mod): + + assert len(inputs) == 4, "There should be 4 input tensors" + num_segments = _infer_value(inputs[3], params, mod).asnumpy().tolist() + return _op.sparse_segment_sqrtn(inputs[0], inputs[1], inputs[2], num_segments) + + return _impl + + +def _sparse_segment_sqrtn(): + def _impl(inputs, attr, params, mod): + + assert len(inputs) == 3, "There should be 3 input tensors" + result = _op.sparse_segment_sqrtn(inputs[0], inputs[1], inputs[2]) + num_segments = _op.add(get_relay_op("max")(inputs[2]), _expr.const([1])) + num_output_shape_dims = len(attr["_output_shapes"][0]) + begin_indices = _op.repeat(_expr.const([0]), num_output_shape_dims, 0) + end_indices = num_segments + if num_output_shape_dims > 1: + end_indices = _op.concatenate( + [ + end_indices, + _op.repeat(_expr.const([-1]), num_output_shape_dims - 1, 0), + ], + 0, + ) + strides = _op.repeat(_expr.const([1]), num_output_shape_dims, 0) + return _op.strided_slice( + result, + begin=begin_indices, + end=end_indices, + strides=strides, + slice_mode="size", + ) + + return _impl + + def _identity(): def _impl(inputs, attr, params, mod): return inputs[0] @@ -2423,6 +2462,8 @@ def _impl(inputs, attr, params, mod): "SpaceToDepth": _space_to_depth(), "SparseToDense": _sparse_to_dense(), "SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(), + "SparseSegmentSqrtN": _sparse_segment_sqrtn(), + "SparseSegmentSqrtNWithNumSegments": _sparse_segment_sqrtn_with_num_segments(), "Split": _split(False), "SplitV": _split(True), "Sqrt": AttrCvt("sqrt"), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 05ca6d2e4bb9..bfafdfcb17b4 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -63,6 +63,8 @@ _reg.register_injective_schedule("sparse_to_dense") _reg.register_injective_schedule("matrix_set_diag") _reg.register_injective_schedule("adv_index") +_reg.register_injective_schedule("sparse_segment_sqrtn") + # concatenate _reg.register_schedule("concatenate", strategy.schedule_concatenate) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 7e7f9b299593..18053e282f16 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1320,3 +1320,50 @@ def adv_index(inputs): Output tensor. """ return _make.adv_index(Tuple(inputs)) + + +def sparse_segment_sqrtn(data, indices, segment_ids, num_segments=None): + """ + Computes the sum along sparse segments of a tensor divided by the sqrt of the length of segment + Reference Link: https://www.tensorflow.org/api_docs/python/tf/sparse/segment_sqrt_n?hl=bn + Parameters + ---------- + data : relay.Expr + A Tensor with data that will be assembled in the output. + indices : relay.Expr + A 1-D Tensor with indices into data. Has same rank as segment_ids. + segment_ids : relay.Expr + A 1-D Tensor with indices into the output Tensor. Values should be sorted and can be + repeated. + num_segments : Optional[int] + An optional int32 scalar. Indicates the size of the output Tensor. + + Returns + ------- + result: relay.Expr + Output tensor containing the sparse segment sum + + Examples + -------- + .. code-block:: python + + data = [[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]] + + indices = [0, 1] + + segment_ids = [0, 2] + + num_segments = 4 + + result = relay.sparse_segment_sqrtn(data, + indices, + segment_ids, + num_segments) + result = [[ 1 2 3 4] + [ 0 0 0 0] + [-1 -2 -3 -4] + [ 0 0 0 0]] + """ + if not num_segments: + num_segments = -1 + return _make.sparse_segment_sqrtn(data, indices, segment_ids, num_segments) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 6ddbc73e4666..fe6d30748cdc 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -931,3 +931,50 @@ def adv_index(data, indices): Output tensor """ return cpp.adv_index(data, indices) + + +def sparse_segment_sqrtn(data, indices, segment_ids, num_segments=None): + """ + Computes the sum along sparse segments of a tensor divided by the sqrt of the length of segment + Reference Link: https://www.tensorflow.org/api_docs/python/tf/sparse/segment_sqrt_n?hl=bn + Parameters + ---------- + data : relay.Expr + A Tensor with data that will be assembled in the output. + indices : relay.Expr + A 1-D Tensor with indices into data. Has same rank as segment_ids. + segment_ids : relay.Expr + A 1-D Tensor with indices into the output Tensor. Values should be sorted and can be + repeated. + num_segments : Optional[int] + An optional int32 scalar. Indicates the size of the output Tensor. + + Returns + ------- + result: relay.Expr + Output tensor containing the sparse segment sum + + Examples + -------- + .. code-block:: python + + data = [[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]] + + indices = [0, 1] + + segment_ids = [0, 2] + + num_segments = 4 + + result = relay.sparse_segment_sqrtn(data, + indices, + segment_ids, + num_segments) + result = [[ 1 2 3 4] + [ 0 0 0 0] + [-1 -2 -3 -4] + [ 0 0 0 0]] + """ + if not num_segments: + num_segments = -1 + return cpp.sparse_segment_sqrtn(data, indices, segment_ids, num_segments) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 640943eac805..387d88801093 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1553,6 +1553,61 @@ RELAY_REGISTER_OP("meshgrid") .set_attr("FTVMCompute", MeshgridCompute) .set_attr("TOpPattern", kInjective); +TVM_REGISTER_NODE_TYPE(SparseSegmentSqrtNAttrs); + +bool SparseSegmentSqrtNRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types: [data, indices, segment_ids, result] + ICHECK_EQ(types.size(), 4) << "SparseSegmentSqrtNRel expects 4 types but " << types.size() + << " were provided"; + auto data = types[0].as(); + auto indices = types[1].as(); + const auto* param = attrs.as(); + ICHECK_NOTNULL(param); + Array new_data_shape; + if (param->num_segments != -1) { + new_data_shape.push_back(param->num_segments); + } else { + new_data_shape.push_back(indices->shape[0]); + } + for (int i = 1; i < static_cast(data->shape.size()); ++i) { + new_data_shape.push_back(data->shape[i]); + } + reporter->Assign(types[3], TensorType(new_data_shape, tvm::DataType::Float(32))); + return true; +} + +Array SparseSegmentSqrtNCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + ICHECK_EQ(inputs.size(), 3) << "SparseSegmentSqrtNCompute expects 3 input but " << inputs.size() + << " were provided"; + const auto* param = attrs.as(); + ICHECK_NOTNULL(param); + return {topi::SparseSegmentSqrtN(inputs[0], inputs[1], inputs[2], param->num_segments)}; +} + +Expr MakeSparseSegmentSqrtN(Expr data, Expr indices, Expr segment_ids, int num_segments) { + auto attrs = make_object(); + attrs->num_segments = std::move(num_segments); + static const Op& op = Op::Get("sparse_segment_sqrtn"); + return Call(op, {data, indices, segment_ids}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.sparse_segment_sqrtn").set_body_typed(MakeSparseSegmentSqrtN); + +RELAY_REGISTER_OP("sparse_segment_sqrtn") + .describe(R"code(Return sparse segment sum of the tensor given segments +)code" TVM_ADD_FILELINE) + .set_num_inputs(3) + .set_attrs_type() + .add_argument("data", "Tensor", "The input data to calculated the op on") + .add_argument("indices", "Tensor", "Selects these indices from data") + .add_argument("segment_ids", "Tensor", "A 1-D Tensor with indices into the output Tensor") + .add_type_rel("sparse_segment_sqrtn", SparseSegmentSqrtNRel) + .set_attr("TOpPattern", kInjective) + .set_support_level(3) + .set_attr("FTVMCompute", SparseSegmentSqrtNCompute); + // tile operator TVM_REGISTER_NODE_TYPE(TileAttrs); diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 22ed6c5b2edf..de40838b3cd4 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -254,6 +254,9 @@ def name_without_num(name): ) # since the names from tensorflow and relay runs are not exactly same, # first len(tf_output) will be compared + # import pdb + + # pdb.set_trace() for i in range(len(tf_output)): if not isinstance(tf_output[i], np.ndarray): assert len(tvm_output[i].shape) == 0 @@ -1811,6 +1814,65 @@ def test_forward_sparse_dense_matmul(): ) +####################################################################### +# SparseSegmentSqrtN +# ------------ + + +def _test_sparse_segment_sqrtn(data_np, indices_np, segment_ids_np, num_segments): + with tf.Graph().as_default(): + data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name="data") + indices = tf.placeholder(shape=indices_np.shape, dtype=indices_np.dtype, name="indices") + segment_ids = tf.constant(segment_ids_np, segment_ids_np.dtype) + + result = tf.sparse.segment_sqrt_n( + data, indices, segment_ids, num_segments=num_segments, name="sparse_segment_sqrtn" + ) + compare_tf_with_tvm( + [data_np, indices_np], + [data.name, indices.name], + result.name, + ) + + +def test_sparse_segment_sqrtn(): + """ sparse_segment_sqrtn test""" + + data_np = np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype=np.float32) + indices_np = np.array([0, 1], dtype=np.int32) + segment_ids_np = np.array([0, 1], dtype=np.int32) + num_segments = 2 + _test_sparse_segment_sqrtn(data_np, indices_np, segment_ids_np, num_segments) + + data_np = np.array([[1, 2, 3, 4], [7, 8, 2, -4], [5, 6, 7, 8]], dtype=np.float32) + indices_np = np.array([0, 1, 2], dtype=np.int32) + segment_ids_np = np.array([0, 0, 2], dtype=np.int32) + num_segments = None + _test_sparse_segment_sqrtn(data_np, indices_np, segment_ids_np, num_segments) + + data_np = np.array( + [ + [[1, 2, 3, 4], [7, 8, 2, -4], [5, 6, 7, -8]], + [[-1, -2, -3, -4], [7, 8, 2, -4], [2, 8, -9, 4]], + [[2, 4, 7, 3], [-3, 2, 5, 7], [7, -1, 3, 6]], + [[1, 2, 3, 4], [7, 8, 2, -4], [5, 6, 7, -8]], + [[-1, -2, -3, -4], [7, 8, 2, -4], [2, 8, -9, 4]], + [[2, 4, 7, 3], [-3, 2, 5, 7], [7, -1, 3, 6]], + ], + dtype=np.float32, + ) + indices_np = np.array([0, 1, 2, 3, 4, 5], dtype=np.int32) + segment_ids_np = np.array([0, 0, 2, 2, 2, 3], dtype=np.int32) + num_segments = None + _test_sparse_segment_sqrtn(data_np, indices_np, segment_ids_np, num_segments) + + data_np = np.array([1, 2, 3, 4], dtype=np.float32) + indices_np = np.array([0, 1, 3], dtype=np.int32) + segment_ids_np = np.array([0, 1, 1], dtype=np.int32) + num_segments = None + _test_sparse_segment_sqrtn(data_np, indices_np, segment_ids_np, num_segments) + + ####################################################################### # StridedSlice # ------------ diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 668285dfb882..e17b22b05c85 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -24,6 +24,8 @@ from tvm.error import TVMError from tvm.relay import create_executor, transform from tvm.relay.testing import check_grad, run_infer_type +from typing import Optional + import tvm.testing @@ -1042,6 +1044,101 @@ def verify_scatter_add(dshape, ishape, axis=0, dtype="float32"): verify_scatter_add((16, 16, 4, 5), (16, 16, 4, 5), 3) +@tvm.testing.uses_gpu +def test_sparse_segment_sqrtn(): + def ref_sparse_segment_sqrtn( + data: np.ndarray, + indices: np.ndarray, + segment_ids: np.ndarray, + num_segments: Optional[int] = None, + ): + """ + This function calculates the expected output of sparse_segment_sqrtn operator given the inputs. + """ + selected_data = np.take(data, indices, axis=0, mode="clip") + if num_segments: + result = np.zeros(((num_segments,) + data.shape[1:]), dtype=np.float32) + else: + result = np.zeros(((indices.shape[0],) + data.shape[1:]), dtype=np.float32) + num_segments = -1 + length_segments = [0 for _ in range(max(np.max(segment_ids) + 1, num_segments))] + + for element in segment_ids: + length_segments[element] += 1 + + for row_num, element in enumerate(segment_ids): + result[element] += selected_data[row_num] + for row_num, length_segment in enumerate(length_segments): + result[row_num] /= max(np.sqrt(length_segment), 1) + + return result + + def verify_sparse_segment_sqrtn( + data_np: np.ndarray, + indices_np: np.ndarray, + segment_ids_np: np.ndarray, + num_segments: Optional[int] = None, + ): + """ + This function verifies the relay output of sparse_segment_sqrtn with its expected output. + """ + data = relay.var( + "data", + relay.TensorType(data_np.shape, str(data_np.dtype)), + ) + indices = relay.var( + "indices", + relay.TensorType(indices_np.shape, str(indices_np.dtype)), + ) + segment_ids = relay.var( + "segment_ids", relay.TensorType(segment_ids_np.shape, str(segment_ids_np.dtype)) + ) + z = relay.op.sparse_segment_sqrtn(data, indices, segment_ids, num_segments) + + func = relay.Function([data, indices, segment_ids], z) + + ref_res = ref_sparse_segment_sqrtn(data_np, indices_np, segment_ids_np, num_segments) + for target, ctx in tvm.testing.enabled_targets(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(data_np, indices_np, segment_ids_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + data_np = np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype=np.float32) + indices_np = np.array([0, 1], dtype=np.int32) + segment_ids_np = np.array([0, 2], dtype=np.int32) + num_segments = 4 + verify_sparse_segment_sqrtn(data_np, indices_np, segment_ids_np, num_segments) + + data_np = np.array([[1, 2, 3, 4], [7, 8, 2, -4], [5, 6, 7, 8]], dtype=np.float32) + indices_np = np.array([0, 1, 2], dtype=np.int32) + segment_ids_np = np.array([0, 0, 2], dtype=np.int32) + num_segments = 4 + verify_sparse_segment_sqrtn(data_np, indices_np, segment_ids_np, num_segments) + + data_np = np.array( + [ + [[1, 2, 3, 4], [7, 8, 2, -4], [5, 6, 7, -8]], + [[-1, -2, -3, -4], [7, 8, 2, -4], [2, 8, -9, 4]], + [[2, 4, 7, 3], [-3, 2, 5, 7], [7, -1, 3, 6]], + [[1, 2, 3, 4], [7, 8, 2, -4], [5, 6, 7, -8]], + [[-1, -2, -3, -4], [7, 8, 2, -4], [2, 8, -9, 4]], + [[2, 4, 7, 3], [-3, 2, 5, 7], [7, -1, 3, 6]], + ], + dtype=np.float32, + ) + indices_np = np.array([0, 1, 2, 3, 4, 5], dtype=np.int32) + segment_ids_np = np.array([0, 0, 2, 2, 2, 3], dtype=np.int32) + num_segments = None + verify_sparse_segment_sqrtn(data_np, indices_np, segment_ids_np, num_segments) + + data_np = np.array([1, 2, 3, 4], dtype=np.float32) + indices_np = np.array([0, 1, 3], dtype=np.int32) + segment_ids_np = np.array([0, 1, 1], dtype=np.int32) + num_segments = None + verify_sparse_segment_sqrtn(data_np, indices_np, segment_ids_np, num_segments) + + @tvm.testing.uses_gpu def test_gather(): def verify_gather(data, axis, indices, ref_res): @@ -1313,6 +1410,11 @@ def verify_adv_index(data_shape, index_shapes): if __name__ == "__main__": + test_sparse_segment_sqrtn() + import sys + + sys.exit() + test_cast() test_zeros_ones() test_unary_identity()