diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index cbe989f93558..06fa13a694bc 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -401,6 +401,16 @@ struct SparseToDenseAttrs : public tvm::AttrsNode<SparseToDenseAttrs> {
   }
 };  // struct SparseToDenseAttrs
 
+/*! \brief Attributes used in sparse_reshape operator */
+struct SparseReshapeAttrs : public tvm::AttrsNode<SparseReshapeAttrs> {
+  Array<Integer> prev_shape;
+  Array<Integer> new_shape;
+  TVM_DECLARE_ATTRS(SparseReshapeAttrs, "relay.attrs.SparseReshapeAttrs") {
+    TVM_ATTR_FIELD(prev_shape).describe("Previous shape of the dense output tensor");
+    TVM_ATTR_FIELD(new_shape).describe("New shape of the dense output tensor");
+  }
+};  // struct SparseReshapeAttrs
+
 /*! \brief Attributes for ndarray_size operator */
 struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
   DataType dtype;
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index a04762f28feb..c61f393c4b26 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -506,7 +506,7 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int a
     begin_ids.push_back(idx);
   }
 
-  Array<Array<PrimExpr> > out_shapes;
+  Array<Array<PrimExpr>> out_shapes;
   for (size_t i = 0; i < begin_ids.size(); ++i) {
     PrimExpr out_axis_size;
     if (i == begin_ids.size() - 1) {
@@ -1386,6 +1386,77 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in
   return result;
 }
 
+/*!
+ * \brief Compute new sparse indices and return them after the sparse_reshape operation
+ *
+ * \param sparse_indices Indices where values of the dense tensor exist
+ * \param sparse_values Values at the above indices respectively
+ * \param prev_shape Old shape of the sparse tensor corresponding to sparse_indices
+ * \param new_shape Desired shape of the sparse tensor which will correspond to output
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the sparse_reshape operation
+ */
+inline Array<Tensor> SparseReshape(const Tensor& sparse_indices, const Tensor& sparse_values,
+                                   Array<Integer> prev_shape, Array<Integer> new_shape,
+                                   const std::string name = "T_sparse_reshape",
+                                   std::string tag = kInjective) {
+  Array<Tensor> result;
+  int new_shape_size = new_shape.size();
+  int prev_shape_size = prev_shape.size();
+  Array<PrimExpr> new_sparse_indices_shape{sparse_indices->shape[0], new_shape_size};
+  std::vector<PrimExpr> multipliers(prev_shape_size, 1);
+  std::vector<PrimExpr> dividers(new_shape_size, 1);
+
+  tvm::PrimExpr total_ele = prev_shape[0];
+  for (int i = prev_shape_size - 2; i >= 0; --i) {
+    multipliers[i] = prev_shape[i + 1] * multipliers[i + 1];
+    total_ele *= prev_shape[i + 1];
+  }
+  PrimExpr division_total_ele = 1;
+  for (int i = 0; i < new_shape_size; ++i) {
+    division_total_ele *= if_then_else(new_shape[i] != -1, new_shape[i], 1);
+  }
+  for (int i = new_shape_size - 2; i >= 0; --i) {
+    dividers[i] = dividers[i + 1] * if_then_else(new_shape[i + 1] != -1, new_shape[i + 1],
+                                                 div(total_ele, division_total_ele));
+  }
+
+  result.push_back(compute(
+      new_sparse_indices_shape,
+      [&](const Array<Var>& indices) {
+        PrimExpr flattened_idx = 0;
+        if (sparse_indices->shape.size() == 1) {
+          flattened_idx += sparse_indices[indices[0]];
+        } else {
+          for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
+            flattened_idx += (sparse_indices[indices[0]][k] * multipliers[k]);
+          }
+        }
+        Array<PrimExpr> new_sparse_indices;
+        if (new_shape_size != 1) {
+          for (int i = 0; i < new_shape_size; i++) {
+            new_sparse_indices.push_back(floordiv(flattened_idx, dividers[i]));
+            flattened_idx = floormod(flattened_idx, dividers[i]);
+          }
+          PrimExpr ret = -1;
+
+          for (int i = 0; i < new_shape_size; i++) {
+            if (indices.size() == 1) {
+              return new_sparse_indices[0];
+            } else {
+              ret = if_then_else(indices[1] == i, new_sparse_indices[i], ret);
+            }
+          }
+          return ret;
+        } else {
+          return flattened_idx;
+        }
+      },
+      name, tag));
+  return result;
+}
+
 /*!
  * \brief Transform the layout according to \p src_layout and \p dst_layout
  * \param src the source input.
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index d5746a38582c..e67ab1f07511 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -975,6 +975,26 @@ def _impl(inputs, attr, params, mod):
     return _impl
 
 
+def _sparse_reshape():
+    def _impl(inputs, attr, params, mod):
+        assert len(inputs) == 3, "There should be 3 input tensors"
+
+        indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
+        values_tensor = np.zeros(indices_tensor.shape[0], dtype=indices_tensor.dtype)
+        prev_shape_tensor = _infer_value(inputs[1], params, mod).asnumpy()
+        new_shape_tensor = _infer_value(inputs[2], params, mod).asnumpy()
+
+        indices_data = _expr.const(indices_tensor, indices_tensor.dtype)
+        values_data = _expr.const(values_tensor, values_tensor.dtype)
+
+        ret = _op.sparse_reshape(
+            indices_data, values_data, list(prev_shape_tensor), list(new_shape_tensor)
+        )
+        return ret, _expr.const(new_shape_tensor, new_shape_tensor.dtype)
+
+    return _impl
+
+
 def _identity():
     def _impl(inputs, attr, params, mod):
         return inputs[0]
@@ -2423,6 +2443,7 @@ def _impl(inputs, attr, params, mod):
     "SpaceToDepth": _space_to_depth(),
     "SparseToDense": _sparse_to_dense(),
     "SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
+    "SparseReshape": _sparse_reshape(),
     "Split": _split(False),
     "SplitV": _split(True),
     "Sqrt": AttrCvt("sqrt"),
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 05ca6d2e4bb9..b54c6b3feb6e 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -63,6 +63,8 @@
 _reg.register_injective_schedule("sparse_to_dense")
 _reg.register_injective_schedule("matrix_set_diag")
 _reg.register_injective_schedule("adv_index")
+_reg.register_injective_schedule("sparse_reshape")
+
 
 # concatenate
 _reg.register_schedule("concatenate", strategy.schedule_concatenate)
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 7e7f9b299593..ac7e873a6adc 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1320,3 +1320,52 @@ def adv_index(inputs):
         Output tensor.
     """
     return _make.adv_index(Tuple(inputs))
+
+
+def sparse_reshape(sparse_indices, sparse_values, prev_shape, new_shape):
+    """
+    Reshape a Sparse Tensor
+
+    Parameters
+    ----------
+    sparse_indices : relay.Expr
+        A 2-D tensor[N, n_dim] of integers containing location of sparse values, where N is the
+        number of sparse values and n_dim is the number of dimensions of the dense_shape
+    sparse_values : relay.Expr
+        A 1-D tensor[N] containing the sparse values for the sparse indices.
+    prev_shape : relay.Expr
+        A 1-D tensor containing the previous shape of the dense tensor
+    new_shape : relay.Expr
+        A 1-D tensor containing the new shape of the dense tensor
+
+    Returns
+    -------
+    result : relay.Expr
+        Output tensor containing the new sparse indices.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        sparse_indices = [[0, 0, 0],
+                          [0, 0, 1],
+                          [0, 1, 0],
+                          [1, 0, 0],
+                          [1, 2, 3]]
+
+        sparse_values = [7, 5, 6, 3, 9]
+
+        prev_shape = [2, 3, 6]
+
+        new_shape = [9, -1]
+
+        relay.sparse_reshape(sparse_indices,
+                             sparse_values,
+                             prev_shape,
+                             new_shape)
+        = [[0, 0],
+           [0, 1],
+           [1, 2],
+           [4, 2],
+           [8, 1]]
+    """
+    return _make.sparse_reshape(sparse_indices, sparse_values, prev_shape, new_shape)
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index 6ddbc73e4666..bda54635ec4a 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -931,3 +931,52 @@ def adv_index(data, indices):
         Output tensor
     """
     return cpp.adv_index(data, indices)
+
+
+def sparse_reshape(sparse_indices, sparse_values, prev_shape, new_shape):
+    """
+    Reshape a Sparse Tensor
+
+    Parameters
+    ----------
+    sparse_indices : relay.Expr
+        A 2-D tensor[N, n_dim] of integers containing location of sparse values, where N is the
+        number of sparse values and n_dim is the number of dimensions of the dense_shape
+    sparse_values : relay.Expr
+        A 1-D tensor[N] containing the sparse values for the sparse indices.
+    prev_shape : relay.Expr
+        A 1-D tensor containing the previous shape of the dense tensor
+    new_shape : relay.Expr
+        A 1-D tensor containing the new shape of the dense tensor
+
+    Returns
+    -------
+    result : relay.Expr
+        Output tensor containing the new sparse indices.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        sparse_indices = [[0, 0, 0],
+                          [0, 0, 1],
+                          [0, 1, 0],
+                          [1, 0, 0],
+                          [1, 2, 3]]
+
+        sparse_values = [7, 5, 6, 3, 9]
+
+        prev_shape = [2, 3, 6]
+
+        new_shape = [9, -1]
+
+        relay.sparse_reshape(sparse_indices,
+                             sparse_values,
+                             prev_shape,
+                             new_shape)
+        = [[0, 0],
+           [0, 1],
+           [1, 2],
+           [4, 2],
+           [8, 1]]
+    """
+    return cpp.sparse_reshape(sparse_indices, sparse_values, prev_shape, new_shape)
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 640943eac805..b52008773a14 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1553,6 +1553,53 @@ RELAY_REGISTER_OP("meshgrid")
     .set_attr<FTVMCompute>("FTVMCompute", MeshgridCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+TVM_REGISTER_NODE_TYPE(SparseReshapeAttrs);
+
+bool SparseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                      const TypeReporter& reporter) {
+  // types: [sparse_indices, sparse_values, result]
+  ICHECK_EQ(types.size(), 3) << "SparseReshapeRel expects 3 types but provided " << types.size();
+  auto sparse_indices = types[0].as<TensorTypeNode>();
+  const auto* param = attrs.as<SparseReshapeAttrs>();
+  ICHECK(param != nullptr);
+  Array<IndexExpr> new_sparse_indices_shape{sparse_indices->shape[0],
+                                            static_cast<int>((param->new_shape).size())};
+  reporter->Assign(types[2], TensorType(new_sparse_indices_shape, sparse_indices->dtype));
+  return true;
+}
+
+Array<te::Tensor> SparseReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                       const Type& out_type) {
+  ICHECK_EQ(inputs.size(), 2) << "SparseReshapeCompute expects 2 inputs but provided "
+                              << inputs.size();
+  const auto* param = attrs.as<SparseReshapeAttrs>();
+  ICHECK(param != nullptr);
+  return {topi::SparseReshape(inputs[0], inputs[1], param->prev_shape, param->new_shape)};
+}
+
+Expr MakeSparseReshape(Expr sparse_indices, Expr sparse_values, Array<Integer> prev_shape,
+                       Array<Integer> new_shape) {
+  auto attrs = make_object<SparseReshapeAttrs>();
+  attrs->prev_shape = std::move(prev_shape);
+  attrs->new_shape = std::move(new_shape);
+  static const Op& op = Op::Get("sparse_reshape");
+  return Call(op, {sparse_indices, sparse_values}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.sparse_reshape").set_body_typed(MakeSparseReshape);
+
+RELAY_REGISTER_OP("sparse_reshape")
+    .describe(R"code(Return new sparse indices of the reshaped tensor
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .set_attrs_type<SparseReshapeAttrs>()
+    .add_argument("sparse_indices", "Tensor", "The sparse indices of the original tensor")
+    .add_argument("sparse_values", "Tensor", "The sparse values of the original tensor")
+    .add_type_rel("sparse_reshape", SparseReshapeRel)
+    .set_attr<TOpPattern>("TOpPattern", kInjective)
+    .set_support_level(3)
+    .set_attr<FTVMCompute>("FTVMCompute", SparseReshapeCompute);
+
 // tile operator
 TVM_REGISTER_NODE_TYPE(TileAttrs);
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 22ed6c5b2edf..50189de527f6 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1811,6 +1811,52 @@ def test_forward_sparse_dense_matmul():
     )
 
 
+def _test_sparse_reshape(indices_np, values_np, prev_shape_np, new_shape_np, dtype):
+    with tf.Graph().as_default():
+        sp_input = tf.sparse.SparseTensor(
+            indices=indices_np, values=values_np, dense_shape=prev_shape_np
+        )
+        new_shape = tf.constant(new_shape_np, new_shape_np.dtype)
+
+        tf.sparse.reshape(sp_input, new_shape, name="sparse_reshape")
+
+        compare_tf_with_tvm(
+            None,
+            "",
+            ["sparse_reshape:0", "sparse_reshape/Identity:0", "sparse_reshape:1"],
+        )
+
+
+def test_forward_sparse_reshape():
+    """sparse_reshape op test"""
+    ###################################################################
+    #
+    # In order to create a SparseTensor, it requires 3 inputs as below:
+    #   SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
+    #
+    # Above Sparse can be represented in Dense as below:
+    #   [[1, 0, 0, 0]
+    #    [0, 0, 2, 0]
+    #    [0, 0, 0, 0]]
+    #
+    # ------------------------------------------------------------------
+    sparse_indices_np = np.array(
+        [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int32
+    )
+    sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
+    prev_shape_np = np.array([2, 3, 6], dtype=np.int32)
+    new_shape_np = np.array([9, 4], dtype=np.int32)
+
+    _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32")
+    sparse_indices_np = np.array(
+        [[0, 0, 0, 0], [0, 0, 1, 2], [0, 1, 0, 3], [1, 0, 0, 4], [1, 2, 3, 6]], dtype=np.int32
+    )
+    sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
+    prev_shape_np = np.array([2, 3, 6, 7], dtype=np.int32)
+    new_shape_np = np.array([9, -1, 7], dtype=np.int32)
+    _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32")
+
+
 #######################################################################
 # StridedSlice
 # ------------
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 668285dfb882..5d0a90b6300d 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1042,6 +1042,117 @@ def verify_scatter_add(dshape, ishape, axis=0, dtype="float32"):
     verify_scatter_add((16, 16, 4, 5), (16, 16, 4, 5), 3)
 
 
+@tvm.testing.uses_gpu
+def test_sparse_reshape():
+    def ref_sparse_reshape(
+        sparse_indices: np.ndarray,
+        sparse_values: np.ndarray,
+        prev_shape: np.ndarray,
+        new_shape: np.ndarray,
+    ):
+        """
+        This function calculates the expected output of sparse_reshape operator given the inputs.
+        """
+        new_sparse_indices = np.ones(
+            (sparse_values.shape[0], new_shape.shape[0]), dtype=sparse_indices.dtype
+        )
+        multipliers = np.ones(prev_shape.shape[0])
+        dividers = np.ones(new_shape.shape[0])
+        total_ele = np.prod(prev_shape)
+        division_total_ele = 1
+        for i in range(new_shape.shape[0]):
+            if new_shape[i] == -1:
+                continue
+            division_total_ele *= new_shape[i]
+        for i in range(prev_shape.shape[0] - 2, -1, -1):
+            multipliers[i] = prev_shape[i + 1] * multipliers[i + 1]
+        for i in range(new_shape.shape[0] - 2, -1, -1):
+            if new_shape[i + 1] == -1:
+                dividers[i] = (total_ele // division_total_ele) * dividers[i + 1]
+            else:
+                dividers[i] = new_shape[i + 1] * dividers[i + 1]
+        for row_num, sparse_row in enumerate(sparse_indices):
+            flat_idx = 0
+            if len(sparse_indices.shape) != 1:
+                for i, ele in enumerate(sparse_row):
+                    flat_idx += ele * multipliers[i]
+            else:
+                flat_idx += sparse_row
+            if len(new_sparse_indices.shape) != 1:
+                for i in range(new_sparse_indices.shape[1]):
+                    new_sparse_indices[row_num][i] = flat_idx // dividers[i]
+                    flat_idx = flat_idx % dividers[i]
+            else:
+                new_sparse_indices[row_num] = flat_idx
+
+        return new_sparse_indices
+
+    def verify_sparse_reshape(
+        sparse_indices_np: np.ndarray,
+        sparse_values_np: np.ndarray,
+        prev_shape_np: np.ndarray,
+        new_shape_np: np.ndarray,
+    ):
+        """
+        This function verifies the relay output of sparse_reshape with its expected output.
+        """
+        sparse_indices = relay.var(
+            "sparse_indices",
+            relay.TensorType(sparse_indices_np.shape, str(sparse_indices_np.dtype)),
+        )
+        sparse_values = relay.var(
+            "sparse_values", relay.TensorType(sparse_values_np.shape, str(sparse_values_np.dtype))
+        )
+        z = relay.op.sparse_reshape(
+            sparse_indices, sparse_values, list(prev_shape_np), list(new_shape_np)
+        )
+
+        func = relay.Function([sparse_indices, sparse_values], z)
+
+        ref_res = ref_sparse_reshape(
+            sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np
+        )
+        for target, ctx in tvm.testing.enabled_targets():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(sparse_indices_np, sparse_values_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+    sparse_indices_np = np.array(
+        [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int32
+    )
+    sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
+    prev_shape_np = np.array([2, 3, 6], dtype=np.int32)
+    new_shape_np = np.array([9, 4], dtype=np.int32)
+    verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np)
+
+    sparse_indices_np = np.array(
+        [[0, 0, 0, 0], [0, 0, 1, 2], [0, 1, 0, 3], [1, 0, 0, 4], [1, 2, 3, 6]], dtype=np.int32
+    )
+    sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
+    prev_shape_np = np.array([2, 3, 6, 7], dtype=np.int32)
+    new_shape_np = np.array([9, -1, 7], dtype=np.int32)
+    verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np)
+
+    sparse_indices_np = np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int32)
+    sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
+    prev_shape_np = np.array([9, 4], dtype=np.int32)
+    new_shape_np = np.array([2, -1, 6], dtype=np.int32)
+    verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np)
+
+    sparse_indices_np = np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int32)
+    sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
+    prev_shape_np = np.array([9, 4], dtype=np.int32)
+    new_shape_np = np.array([-1], dtype=np.int32)
+    verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np)
+
+    sparse_indices_np = np.array([0, 5, 10, 20, 24], dtype=np.int32)
+    sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32)
+    prev_shape_np = np.array([25], dtype=np.int32)
+    new_shape_np = np.array([5, 5], dtype=np.int32)
+    verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np)
+
+
 @tvm.testing.uses_gpu
 def test_gather():
     def verify_gather(data, axis, indices, ref_res):
@@ -1313,6 +1424,7 @@ def verify_adv_index(data_shape, index_shapes):
 
 
 if __name__ == "__main__":
+    test_sparse_reshape()
     test_cast()
     test_zeros_ones()
     test_unary_identity()
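Not part of the patch itself: a minimal standalone NumPy sketch of the row-major flatten/unflatten arithmetic that both `topi::SparseReshape` and the test's `ref_sparse_reshape` implement, run on the docstring example. The helper name `flatten_unflatten` is illustrative only.

```python
import numpy as np

def flatten_unflatten(sparse_indices, prev_shape, new_shape):
    """Flatten each sparse index against prev_shape, then unflatten it against new_shape."""
    prev_shape = np.asarray(prev_shape)
    new_shape = np.asarray(new_shape)
    # Resolve a single -1 in new_shape from the total number of dense elements.
    total = int(np.prod(prev_shape))
    known = int(np.prod(new_shape[new_shape != -1]))
    resolved = np.where(new_shape == -1, total // known, new_shape)
    # Row-major strides ("multipliers"/"dividers" in the patch) for both shapes.
    prev_strides = np.append(np.cumprod(prev_shape[::-1])[::-1][1:], 1)
    new_strides = np.append(np.cumprod(resolved[::-1])[::-1][1:], 1)
    flat = sparse_indices @ prev_strides  # linearize every index
    return np.stack([(flat // s) % d for s, d in zip(new_strides, resolved)], axis=1)

indices = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]])
print(flatten_unflatten(indices, [2, 3, 6], [9, -1]))
# [[0 0] [0 1] [1 2] [4 2] [8 1]] -- matches the sparse_reshape docstring example
```

The sparse values are untouched by the operation, which is why the frontend converter can feed a zero placeholder for `sparse_values` and only the indices and shapes matter for the result.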