Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,16 @@ struct SparseToDenseAttrs : public tvm::AttrsNode<SparseToDenseAttrs> {
}
}; // struct SparseToDenseAttrs

/*! \brief Attributes used in sparse_reshape operator */
struct SparseReshapeAttrs : public tvm::AttrsNode<SparseReshapeAttrs> {
Array<Integer> prev_shape;
Array<Integer> new_shape;
TVM_DECLARE_ATTRS(SparseReshapeAttrs, "relay.attrs.SparseReshapeAttrs") {
TVM_ATTR_FIELD(prev_shape).describe("Previous shape of the dense output tensor");
TVM_ATTR_FIELD(new_shape).describe("New Shape of the dense output tensor");
}
}; // struct SparseReshapeAttrs

/*! \brief Attributes for ndarray_size operator */
struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
DataType dtype;
Expand Down
73 changes: 72 additions & 1 deletion include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int a
begin_ids.push_back(idx);
}

Array<Array<PrimExpr> > out_shapes;
Array<Array<PrimExpr>> out_shapes;
for (size_t i = 0; i < begin_ids.size(); ++i) {
PrimExpr out_axis_size;
if (i == begin_ids.size() - 1) {
Expand Down Expand Up @@ -1386,6 +1386,77 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in
return result;
}

/*!
* \brief Compute new sparse indices and return them after the sparse_reshape operation
*
* \param sparse_indices Indices where values of the dense tensor exist
* \param sparse_values Values at the above indices respectively
* \param prev_shape Old Shape of the sparse tensor corresponding to sparse_indices
* \param new_shape Desired Shape of the sparse tensor which will correspond to output
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the sparse_reshape operation
*/
inline Array<Tensor> SparseReshape(const Tensor& sparse_indices, const Tensor& sparse_values,
Array<Integer> prev_shape, Array<Integer> new_shape,
const std::string name = "T_sparse_reshape",
std::string tag = kInjective) {
Array<Tensor> result;
int new_shape_size = new_shape.size();
int prev_shape_size = prev_shape.size();
Array<PrimExpr> new_sparse_indices_shape{sparse_indices->shape[0], new_shape_size};
std::vector<PrimExpr> multipliers(prev_shape_size, 1);
std::vector<PrimExpr> dividers(new_shape_size, 1);

tvm::PrimExpr total_ele = prev_shape[0];
for (int i = prev_shape_size - 2; i >= 0; --i) {
multipliers[i] = prev_shape[i + 1] * multipliers[i + 1];
total_ele *= prev_shape[i + 1];
}
PrimExpr division_total_ele = 1;
for (int i = 0; i < new_shape_size; ++i) {
division_total_ele *= if_then_else(new_shape[i] != -1, new_shape[i], 1);
}
for (int i = new_shape_size - 2; i >= 0; --i) {
dividers[i] = dividers[i + 1] * if_then_else(new_shape[i + 1] != -1, new_shape[i + 1],
div(total_ele, division_total_ele));
}

result.push_back(compute(
new_sparse_indices_shape,
[&](const Array<Var>& indices) {
PrimExpr flattened_idx = 0;
if (sparse_indices->shape.size() == 1) {
flattened_idx += sparse_indices[indices[0]];
} else {
for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
flattened_idx += (sparse_indices[indices[0]][k] * multipliers[k]);
}
}
Array<PrimExpr> new_sparse_indices;
if (new_shape_size != 1) {
for (int i = 0; i < new_shape_size; i++) {
new_sparse_indices.push_back(floordiv(flattened_idx, dividers[i]));
flattened_idx = floormod(flattened_idx, dividers[i]);
}
PrimExpr ret = -1;

for (int i = 0; i < new_shape_size; i++) {
if (indices.size() == 1) {
return new_sparse_indices[0];
} else {
ret = if_then_else(indices[1] == i, new_sparse_indices[i], ret);
}
}
return ret;
} else {
return flattened_idx;
}
},
name, tag));
return result;
} // namespace topi
/*!
* \brief Transform the layout according to \p src_layout and \p dst_layout
* \param src the source input.
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,26 @@ def _impl(inputs, attr, params, mod):
return _impl


def _sparse_reshape():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 3, "There should be 3 input tensors"

indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
values_tensor = np.zeros(indices_tensor.shape[0], dtype=indices_tensor.dtype)
prev_shape_tensor = _infer_value(inputs[1], params, mod).asnumpy()
new_shape_tensor = _infer_value(inputs[2], params, mod).asnumpy()

indices_data = _expr.const(indices_tensor, indices_tensor.dtype)
values_data = _expr.const(values_tensor, values_tensor.dtype)

ret = _op.sparse_reshape(
indices_data, values_data, list(prev_shape_tensor), list(new_shape_tensor)
)
return ret, _expr.const(new_shape_tensor, new_shape_tensor.dtype)

return _impl


def _identity():
def _impl(inputs, attr, params, mod):
return inputs[0]
Expand Down Expand Up @@ -2423,6 +2443,9 @@ def _impl(inputs, attr, params, mod):
"SpaceToDepth": _space_to_depth(),
"SparseToDense": _sparse_to_dense(),
"SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
"SparseReshape": _sparse_reshape(),
"SparseFillEmptyRows": _sparse_reshape(),
"SparseSegmentSqrtN": _sparse_reshape(),
"Split": _split(False),
"SplitV": _split(True),
"Sqrt": AttrCvt("sqrt"),
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
_reg.register_injective_schedule("sparse_to_dense")
_reg.register_injective_schedule("matrix_set_diag")
_reg.register_injective_schedule("adv_index")
_reg.register_injective_schedule("sparse_reshape")


# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
Expand Down
49 changes: 49 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,3 +1320,52 @@ def adv_index(inputs):
Output tensor.
"""
return _make.adv_index(Tuple(inputs))


def sparse_reshape(sparse_indices, sparse_values, prev_shape, new_shape):
"""
Reshape a Sparse Tensor

Parameters
----------
sparse_indices : relay.Expr
A 2-D tensor[N, n_dim] of integers containing location of sparse values, where N is the
number of sparse values and n_dim is the number of dimensions of the dense_shape
sparse_values : relay.Expr
A 1-D tensor[N] containing the sparse values for the sparse indices.
prev_shape : relay.Expr
A 1-D tensor containing the previous shape of the dense tensor
new_shape : relay.Expr
A 1-D tensor containing the new shape of the dense tensor

Returns
-------
result: relay.Expr
Output tensor.
Examples
--------
.. code-block:: python

sparse_indices = [[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[1, 0, 0],
[1, 2, 3]]

sparse_values = [7, 5, 6, 3, 9]

prev_shape = [2, 3, 4]

new_shape = [9, -1]

relay.sparse_reshape(sparse_indices,
sparse_values,
prev_shape,
new_shape)
= [[0, 0],
[0, 1],
[1, 2],
[4, 2],
[8, 1]]
"""
return _make.sparse_reshape(sparse_indices, sparse_values, prev_shape, new_shape)
49 changes: 49 additions & 0 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,3 +931,52 @@ def adv_index(data, indices):
Output tensor
"""
return cpp.adv_index(data, indices)


def sparse_reshape(sparse_indices, sparse_values, prev_shape, new_shape):
"""
Reshape a Sparse Tensor

Parameters
----------
sparse_indices : relay.Expr
A 2-D tensor[N, n_dim] of integers containing location of sparse values, where N is the
number of sparse values and n_dim is the number of dimensions of the dense_shape
sparse_values : relay.Expr
A 1-D tensor[N] containing the sparse values for the sparse indices.
prev_shape : relay.Expr
A 1-D tensor containing the previous shape of the dense tensor
new_shape : relay.Expr
A 1-D tensor containing the new shape of the dense tensor

Returns
-------
result: relay.Expr
Output tensor.
Examples
--------
.. code-block:: python

sparse_indices = [[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[1, 0, 0],
[1, 2, 3]]

sparse_values = [7, 5, 6, 3, 9]

prev_shape = [2, 3, 4]

new_shape = [9, -1]

relay.sparse_reshape(sparse_indices,
sparse_values,
prev_shape,
new_shape)
= [[0, 0],
[0, 1],
[1, 2],
[4, 2],
[8, 1]]
"""
return cpp.sparse_reshape(sparse_indices, sparse_values, prev_shape, new_shape)
47 changes: 47 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,53 @@ RELAY_REGISTER_OP("meshgrid")
.set_attr<FTVMCompute>("FTVMCompute", MeshgridCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

TVM_REGISTER_NODE_TYPE(SparseReshapeAttrs);

bool SparseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [sparse_indices, sparse_values, result]
ICHECK_EQ(types.size(), 3) << "SparseReshapeRel expects 3 types but provided " << types.size();
auto sparse_indices = types[0].as<TensorTypeNode>();
const auto* param = attrs.as<SparseReshapeAttrs>();
ICHECK(param != nullptr);
Array<PrimExpr> new_sparse_indices_shape{sparse_indices->shape[0],
static_cast<int>((param->new_shape).size())};
reporter->Assign(types[2], TensorType(new_sparse_indices_shape, sparse_indices->dtype));
return true;
}

Array<te::Tensor> SparseReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
ICHECK_EQ(inputs.size(), 2) << "SparseReshapeCompute expects 2 input but provided "
<< inputs.size();
const auto* param = attrs.as<SparseReshapeAttrs>();
ICHECK(param != nullptr);
return {topi::SparseReshape(inputs[0], inputs[1], param->prev_shape, param->new_shape)};
}

Expr MakeSparseReshape(Expr sparse_indices, Expr sparse_values, Array<Integer> prev_shape,
Array<Integer> new_shape) {
auto attrs = make_object<SparseReshapeAttrs>();
attrs->prev_shape = std::move(prev_shape);
attrs->new_shape = std::move(new_shape);
static const Op& op = Op::Get("sparse_reshape");
return Call(op, {sparse_indices, sparse_values}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.sparse_reshape").set_body_typed(MakeSparseReshape);

RELAY_REGISTER_OP("sparse_reshape")
.describe(R"code(Return new sparse indices of the reshaped tensor
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.set_attrs_type<SparseReshapeAttrs>()
.add_argument("sparse_indices", "Tensor", "The first tensor")
.add_argument("sparse_values", "Tensor", "The second tensor")
.add_type_rel("sparse_reshape", SparseReshapeRel)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_support_level(3)
.set_attr<FTVMCompute>("FTVMCompute", SparseReshapeCompute);

// tile operator
TVM_REGISTER_NODE_TYPE(TileAttrs);

Expand Down
56 changes: 55 additions & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,6 +1811,58 @@ def test_forward_sparse_dense_matmul():
)


def _test_sparse_reshape(indices_np, values_np, prev_shape_np, new_shape_np, dtype):
with tf.Graph().as_default():
sp_input = tf.sparse.SparseTensor(
indices=indices_np, values=values_np, dense_shape=prev_shape_np
)
new_shape = tf.constant(new_shape_np, new_shape_np.dtype)
# new_shape = tf.placeholder(
# shape=new_shape_np.shape, dtype=new_shape_np.dtype, name="new_shape"
# )

tf.sparse.reshape(sp_input, new_shape, name="sparse_reshape")

# import pdb

# pdb.set_trace()
compare_tf_with_tvm(
None,
"",
["sparse_reshape:0", "sparse_reshape/Identity:0", "sparse_reshape:1"],
)


def test_forward_sparse_reshape():
""" sparse_reshape op test"""
###################################################################
#
# In order to create a SparseTensor, it requires 3 input as below:
# SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
#
# Above Sparse can be represented in Dense as below :
# [[1, 0, 0, 0]
# [0, 0, 2, 0]
# [0, 0, 0, 0]]
#
# ------------------------------------------------------------------
sparse_indices_np = np.array(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int32
)
sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
prev_shape_np = np.array([2, 3, 6], dtype=np.int32)
new_shape_np = np.array([9, 4], dtype=np.int32)

_test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32")
sparse_indices_np = np.array(
[[0, 0, 0, 0], [0, 0, 1, 2], [0, 1, 0, 3], [1, 0, 0, 4], [1, 2, 3, 6]], dtype=np.int32
)
sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
prev_shape_np = np.array([2, 3, 6, 7], dtype=np.int32)
new_shape_np = np.array([9, -1, 7], dtype=np.int32)
_test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32")


#######################################################################
# StridedSlice
# ------------
Expand Down Expand Up @@ -4682,4 +4734,6 @@ def lstm_cell():


if __name__ == "__main__":
pytest.main([__file__])
# test_forward_sparse_to_dense()
test_forward_sparse_reshape()
# pytest.main([__file__])
Loading