49 commits
81e4c7f  Fix trt Test (codeislife99, Dec 2, 2020)
8e2ce9a  Fixed stuff (Dec 2, 2020)
d719346  Test TRT (Dec 2, 2020)
4e160e9  Done (Dec 2, 2020)
3104113  fix 0 (Dec 2, 2020)
24116fd  Trigger Build (Dec 3, 2020)
e8b223f  Done (Dec 15, 2020)
ff29e0c  SparseReshapeOp (Dec 17, 2020)
fe3f7de  Remove Build Module changes (Dec 17, 2020)
4fd2e57  Merge (Dec 17, 2020)
3f5de52  Reset non-op changes (Dec 17, 2020)
a521c1b  Remove stuff (Dec 17, 2020)
04da7d4  More changes (Dec 17, 2020)
c32b2dd  Op Level 3 (Dec 17, 2020)
fa5def3  Make Op changes only (Dec 17, 2020)
dc8d1ce  Formatting Changes (Dec 17, 2020)
2d48888  Only Transform changes (Dec 17, 2020)
2e017fd  Correct Clang format version (Dec 17, 2020)
b7000ac  Reset_to_clang-format-10 (Dec 17, 2020)
ea354d4  Merge Main (Dec 17, 2020)
f14672a  Remove SparseFill Changes (Dec 17, 2020)
0690155  PR stuff (Dec 18, 2020)
8c1f1f4  Done (Dec 18, 2020)
2730eff  Done (Dec 18, 2020)
58bae15  Add Brief; (Dec 18, 2020)
b1cbce0  Black (Dec 18, 2020)
bee77e0  Address comments (Dec 21, 2020)
5e08bcf  Address PR COmments (Dec 21, 2020)
9954551  Change op Name to sparse_reshape (Dec 22, 2020)
7508985  PR Comments (Dec 22, 2020)
1072e58  Initial Code for Sparse Segment Sum (Dec 22, 2020)
e143ffd  Done main work (Dec 22, 2020)
3d1a60c  try something (Dec 22, 2020)
3c73e40  Black (Dec 22, 2020)
526af81  Add docstring (Dec 22, 2020)
86558eb  Finish Test Cases (Dec 22, 2020)
43b5f6b  SparseSegmentSumSqrtN (Jan 2, 2021)
3ebdaa7  Done (Jan 2, 2021)
31a1451  Add TF Frontend (Jan 2, 2021)
7edf7e9  Op Level Testing (Jan 2, 2021)
fb91d96  CI (Jan 4, 2021)
e35122e  Update src/relay/op/tensor/transform.cc (codeislife99, Jan 4, 2021)
3cd4e7c  Merge branch 'SparseSegmentSumOp' of github.com:codeislife99/incubato… (Jan 4, 2021)
2e4a1ba  Update Description (Jan 4, 2021)
166b61a  Logging.h (Jan 4, 2021)
126d2b0  CI (Jan 4, 2021)
c3aea20  Update src/relay/op/tensor/transform.cc (codeislife99, Jan 4, 2021)
3b64901  Address comments (Jan 4, 2021)
b7120f0  Comments (Jan 4, 2021)
8 changes: 8 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -401,6 +401,14 @@ struct SparseToDenseAttrs : public tvm::AttrsNode<SparseToDenseAttrs> {
}
}; // struct SparseToDenseAttrs

/*! \brief Attributes used in sparse_segment_sqrtn operator */
struct SparseSegmentSqrtNAttrs : public tvm::AttrsNode<SparseSegmentSqrtNAttrs> {
int num_segments;
TVM_DECLARE_ATTRS(SparseSegmentSqrtNAttrs, "relay.attrs.SparseSegmentSqrtNAttrs") {
TVM_ATTR_FIELD(num_segments).describe("Number of Segments in the output tensor");
}
}; // struct SparseSegmentSqrtNAttrs

/*! \brief Attributes for ndarray_size operator */
struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
DataType dtype;
8 changes: 4 additions & 4 deletions include/tvm/support/logging.h
@@ -139,10 +139,10 @@ constexpr const char* kTVM_INTERNAL_ERROR_MESSAGE =
#define ICHECK_GE(x, y) ICHECK_BINARY_OP(_GE, >=, x, y)
#define ICHECK_EQ(x, y) ICHECK_BINARY_OP(_EQ, ==, x, y)
#define ICHECK_NE(x, y) ICHECK_BINARY_OP(_NE, !=, x, y)
#define ICHECK_NOTNULL(x) \
((x) == nullptr ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \
<< tvm::kTVM_INTERNAL_ERROR_MESSAGE << __INDENT << "Check not null: " #x \
<< ' ', \
#define ICHECK_NOTNULL(x) \
((x) == nullptr ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \
<< tvm::kTVM_INTERNAL_ERROR_MESSAGE << ICHECK_INDENT \
<< "Check not null: " #x << ' ', \
(x) : (x)) // NOLINT(*)

/*! \brief The diagnostic level, controls the printing of the message. */
53 changes: 52 additions & 1 deletion include/tvm/topi/transform.h
@@ -506,7 +506,7 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int a
begin_ids.push_back(idx);
}

Array<Array<PrimExpr> > out_shapes;
Array<Array<PrimExpr>> out_shapes;
for (size_t i = 0; i < begin_ids.size(); ++i) {
PrimExpr out_axis_size;
if (i == begin_ids.size() - 1) {
@@ -1386,6 +1386,57 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in
return result;
}

/*!
* \brief Compute the sparse segment sum on the indices over the segment_ids divided by the sqrt
* of the length of the segment
*
* \param data A Tensor with data that will be assembled in the output.
* \param selected_indices A 1-D Tensor with indices into data. Has same rank as segment_ids.
* \param segment_ids A 1-D Tensor with indices into the output Tensor. Values should be sorted and
* can be repeated.
* \param num_segments An optional int32 scalar. Indicates the size of the output Tensor.
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the sparse_segment_sqrtn operation
*/
inline Array<Tensor> SparseSegmentSqrtN(const Tensor& data, const Tensor& selected_indices,
const Tensor& segment_ids, int num_segments,
const std::string name = "T_sparse_segment_sqrtn",
std::string tag = kInjective) {
Array<Tensor> result;
Array<PrimExpr> new_data_shape;
if (num_segments != -1) {
new_data_shape.push_back(num_segments);
} else {
new_data_shape.push_back(selected_indices->shape[0]);
}
for (int i = 1; i < static_cast<int>(data->shape.size()); ++i) {
new_data_shape.push_back(data->shape[i]);
}
auto selected_data = tvm::topi::take(data, selected_indices, 0, "clip");

result.push_back(compute(
new_data_shape,
[&](const Array<Var>& indices) {
PrimExpr ret = static_cast<float>(0.0);
PrimExpr length_segment = static_cast<float>(0.0);
for (int i = 0; i < GetConstInt(segment_ids->shape[0]); ++i) {
Array<PrimExpr> secondary_indices;
secondary_indices.push_back(i);
secondary_indices.insert(secondary_indices.end(), indices.begin() + 1, indices.end());
PrimExpr condition = indices[0] == segment_ids[i];
length_segment += if_then_else(condition, 1, 0);
ret += if_then_else(condition, selected_data(secondary_indices), 0);
}
// length_segment = if_then_else(length_segment == 0, 1, length_segment);
PrimExpr sqrt_length_segment =
tvm::sqrt(if_then_else(length_segment == 0, 1, length_segment));
return div(ret, sqrt_length_segment);
},
name, tag));
return result;
}
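
For reference (not part of this diff): a minimal NumPy sketch of the semantics the compute above implements; the helper name sparse_segment_sqrt_n_ref is hypothetical.

import numpy as np

def sparse_segment_sqrt_n_ref(data, indices, segment_ids, num_segments=None):
    # Gather the selected rows, sum them per segment id, then divide each
    # segment sum by sqrt(segment length); empty segments stay all-zero.
    selected = data[np.asarray(indices)]          # mirrors topi::take(data, indices, 0)
    n_out = num_segments if num_segments is not None else len(indices)
    out = np.zeros((n_out,) + data.shape[1:], dtype="float32")
    counts = np.zeros(n_out, dtype="float32")
    for row, seg in zip(selected, segment_ids):
        out[seg] += row
        counts[seg] += 1.0
    counts[counts == 0] = 1.0                     # matches the if_then_else guard above
    return out / np.sqrt(counts).reshape((-1,) + (1,) * (data.ndim - 1))

data = np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype="float32")
print(sparse_segment_sqrt_n_ref(data, [0, 2, 1], [0, 0, 1], num_segments=2))
# Segment 0 sums rows 0 and 2 and divides by sqrt(2); segment 1 is row 1 / sqrt(1).
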
/*!
* \brief Transform the layout according to \p src_layout and \p dst_layout
* \param src the source input.
41 changes: 41 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -975,6 +975,45 @@ def _impl(inputs, attr, params, mod):
return _impl


def _sparse_segment_sqrtn_with_num_segments():
def _impl(inputs, attr, params, mod):

assert len(inputs) == 4, "There should be 4 input tensors"
num_segments = _infer_value(inputs[3], params, mod).asnumpy().tolist()
return _op.sparse_segment_sqrtn(inputs[0], inputs[1], inputs[2], num_segments)

return _impl


def _sparse_segment_sqrtn():
def _impl(inputs, attr, params, mod):

assert len(inputs) == 3, "There should be 3 input tensors"
result = _op.sparse_segment_sqrtn(inputs[0], inputs[1], inputs[2])
num_segments = _op.add(get_relay_op("max")(inputs[2]), _expr.const([1]))
num_output_shape_dims = len(attr["_output_shapes"][0])
begin_indices = _op.repeat(_expr.const([0]), num_output_shape_dims, 0)
end_indices = num_segments
if num_output_shape_dims > 1:
end_indices = _op.concatenate(
[
end_indices,
_op.repeat(_expr.const([-1]), num_output_shape_dims - 1, 0),
],
0,
)
strides = _op.repeat(_expr.const([1]), num_output_shape_dims, 0)
return _op.strided_slice(
result,
begin=begin_indices,
end=end_indices,
strides=strides,
slice_mode="size",
)

return _impl
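
To illustrate why this converter trims its result: without an explicit num_segments the Relay op returns len(indices) output rows, while TensorFlow returns max(segment_ids) + 1 rows, hence the strided_slice. A small TensorFlow 2.x eager sketch (not part of this diff; the frontend itself consumes a frozen GraphDef) showing the op instance being converted:

import numpy as np
import tensorflow as tf

data = np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype="float32")
# Only segment ids 0 and 1 appear, so TensorFlow returns max(segment_ids) + 1 == 2
# rows, while the Relay op without num_segments would produce len(indices) == 3 rows.
out = tf.sparse.segment_sqrt_n(
    tf.constant(data), tf.constant([0, 2, 1]), tf.constant([0, 0, 1])
)
print(out.numpy())
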


def _identity():
def _impl(inputs, attr, params, mod):
return inputs[0]
@@ -2423,6 +2462,8 @@ def _impl(inputs, attr, params, mod):
"SpaceToDepth": _space_to_depth(),
"SparseToDense": _sparse_to_dense(),
"SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
"SparseSegmentSqrtN": _sparse_segment_sqrtn(),
"SparseSegmentSqrtNWithNumSegments": _sparse_segment_sqrtn_with_num_segments(),
"Split": _split(False),
"SplitV": _split(True),
"Sqrt": AttrCvt("sqrt"),
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -63,6 +63,8 @@
_reg.register_injective_schedule("sparse_to_dense")
_reg.register_injective_schedule("matrix_set_diag")
_reg.register_injective_schedule("adv_index")
_reg.register_injective_schedule("sparse_segment_sqrtn")


# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
47 changes: 47 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -1320,3 +1320,50 @@ def adv_index(inputs):
Output tensor.
"""
return _make.adv_index(Tuple(inputs))


def sparse_segment_sqrtn(data, indices, segment_ids, num_segments=None):
"""
Computes the sum along sparse segments of a tensor, divided by the square root of each segment's length.
Reference: https://www.tensorflow.org/api_docs/python/tf/sparse/segment_sqrt_n?hl=bn
Parameters
----------
data : relay.Expr
A Tensor with data that will be assembled in the output.
indices : relay.Expr
A 1-D Tensor with indices into data. Has same rank as segment_ids.
segment_ids : relay.Expr
A 1-D Tensor with indices into the output Tensor. Values should be sorted and can be
repeated.
num_segments : Optional[int]
An optional int32 scalar. Indicates the size of the output Tensor.

Returns
-------
result: relay.Expr
Output tensor containing the sparse segment sum

Examples
--------
.. code-block:: python

data = [[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]

indices = [0, 1]

segment_ids = [0, 2]

num_segments = 4

result = relay.sparse_segment_sqrtn(data,
indices,
segment_ids,
num_segments)
result = [[ 1 2 3 4]
[ 0 0 0 0]
[-1 -2 -3 -4]
[ 0 0 0 0]]
"""
if num_segments is None:
num_segments = -1
return _make.sparse_segment_sqrtn(data, indices, segment_ids, num_segments)
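
Not part of this diff: a minimal sketch of exercising the new op end to end, assuming the stock Relay graph executor API of this TVM version (relay.create_executor) and the relay-level export of the function above.

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(3, 4), dtype="float32")
indices = relay.var("indices", shape=(2,), dtype="int32")
segment_ids = relay.var("segment_ids", shape=(2,), dtype="int32")
out = relay.sparse_segment_sqrtn(data, indices, segment_ids, num_segments=4)
mod = tvm.IRModule.from_expr(relay.Function([data, indices, segment_ids], out))

# Runs the docstring example above; every segment has length 1, so the sqrt is a no-op.
result = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(
    np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype="float32"),
    np.array([0, 1], dtype="int32"),
    np.array([0, 2], dtype="int32"),
)
print(result)  # rows 0 and 2 hold the selected data, rows 1 and 3 are zero
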
47 changes: 47 additions & 0 deletions python/tvm/topi/transform.py
@@ -931,3 +931,50 @@ def adv_index(data, indices):
Output tensor
"""
return cpp.adv_index(data, indices)


def sparse_segment_sqrtn(data, indices, segment_ids, num_segments=None):
"""
Computes the sum along sparse segments of a tensor, divided by the square root of each segment's length.
Reference: https://www.tensorflow.org/api_docs/python/tf/sparse/segment_sqrt_n?hl=bn
Parameters
----------
data : tvm.te.Tensor
A Tensor with data that will be assembled in the output.
indices : tvm.te.Tensor
A 1-D Tensor with indices into data. Has same rank as segment_ids.
segment_ids : tvm.te.Tensor
A 1-D Tensor with indices into the output Tensor. Values should be sorted and can be
repeated.
num_segments : Optional[int]
An optional int32 scalar. Indicates the size of the output Tensor.

Returns
-------
result : tvm.te.Tensor
Output tensor containing the sparse segment sum

Examples
--------
.. code-block:: python

data = [[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]

indices = [0, 1]

segment_ids = [0, 2]

num_segments = 4

result = topi.sparse_segment_sqrtn(data,
indices,
segment_ids,
num_segments)
result = [[ 1 2 3 4]
[ 0 0 0 0]
[-1 -2 -3 -4]
[ 0 0 0 0]]
"""
if num_segments is None:
num_segments = -1
return cpp.sparse_segment_sqrtn(data, indices, segment_ids, num_segments)
55 changes: 55 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -1553,6 +1553,61 @@ RELAY_REGISTER_OP("meshgrid")
.set_attr<FTVMCompute>("FTVMCompute", MeshgridCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

TVM_REGISTER_NODE_TYPE(SparseSegmentSqrtNAttrs);

bool SparseSegmentSqrtNRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, indices, segment_ids, result]
ICHECK_EQ(types.size(), 4) << "SparseSegmentSqrtNRel expects 4 types but " << types.size()
<< " were provided";
auto data = types[0].as<TensorTypeNode>();
auto indices = types[1].as<TensorTypeNode>();
const auto* param = attrs.as<SparseSegmentSqrtNAttrs>();
ICHECK_NOTNULL(param);
Array<PrimExpr> new_data_shape;
if (param->num_segments != -1) {
new_data_shape.push_back(param->num_segments);
} else {
new_data_shape.push_back(indices->shape[0]);
}
for (int i = 1; i < static_cast<int>(data->shape.size()); ++i) {
new_data_shape.push_back(data->shape[i]);
}
reporter->Assign(types[3], TensorType(new_data_shape, tvm::DataType::Float(32)));
return true;
}

Array<te::Tensor> SparseSegmentSqrtNCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
ICHECK_EQ(inputs.size(), 3) << "SparseSegmentSqrtNCompute expects 3 inputs but " << inputs.size()
<< " were provided";
const auto* param = attrs.as<SparseSegmentSqrtNAttrs>();
ICHECK_NOTNULL(param);
return {topi::SparseSegmentSqrtN(inputs[0], inputs[1], inputs[2], param->num_segments)};
}

Expr MakeSparseSegmentSqrtN(Expr data, Expr indices, Expr segment_ids, int num_segments) {
auto attrs = make_object<SparseSegmentSqrtNAttrs>();
attrs->num_segments = num_segments;
static const Op& op = Op::Get("sparse_segment_sqrtn");
return Call(op, {data, indices, segment_ids}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.sparse_segment_sqrtn").set_body_typed(MakeSparseSegmentSqrtN);

RELAY_REGISTER_OP("sparse_segment_sqrtn")
.describe(R"code(Return sparse segment sum of the tensor given segments
)code" TVM_ADD_FILELINE)
.set_num_inputs(3)
.set_attrs_type<SparseSegmentSqrtNAttrs>()
.add_argument("data", "Tensor", "The input data to calculated the op on")
.add_argument("indices", "Tensor", "Selects these indices from data")
.add_argument("segment_ids", "Tensor", "A 1-D Tensor with indices into the output Tensor")
.add_type_rel("sparse_segment_sqrtn", SparseSegmentSqrtNRel)
.set_attr<TOpPattern>("TOpPattern", kInjective)
Review comment (Contributor): @mbrookhart Is kInjective the correct pattern here?
.set_support_level(3)
.set_attr<FTVMCompute>("FTVMCompute", SparseSegmentSqrtNCompute);

// tile operator
TVM_REGISTER_NODE_TYPE(TileAttrs);
