diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index a04762f28feb..56edb45f4a9b 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -506,7 +506,7 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int a
     begin_ids.push_back(idx);
   }
 
-  Array<Array<PrimExpr> > out_shapes;
+  Array<Array<PrimExpr>> out_shapes;
   for (size_t i = 0; i < begin_ids.size(); ++i) {
     PrimExpr out_axis_size;
     if (i == begin_ids.size() - 1) {
@@ -1386,6 +1386,88 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in
   return result;
 }
 
+/*!
+ * \brief Compute new sparse indices and return them after the sparse_reshape operation
+ *
+ * \param sparse_indices Indices where values of the dense tensor exist
+ * \param prev_shape Old shape of the sparse tensor corresponding to sparse_indices
+ * \param new_shape Desired shape of the sparse tensor which will correspond to output
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the sparse_reshape operation
+ */
+inline Array<Tensor> SparseReshape(const Tensor& sparse_indices, const Tensor& prev_shape,
+                                   const Tensor& new_shape,
+                                   const std::string name = "T_sparse_reshape",
+                                   std::string tag = kInjective) {
+  Array<Tensor> result;
+  Array<PrimExpr> new_sparse_indices_shape{sparse_indices->shape[0], new_shape->shape[0]};
+
+  int new_shape_size = GetConstInt(new_shape->shape[0]);
+  int prev_shape_size = GetConstInt(prev_shape->shape[0]);
+  std::vector<PrimExpr> multipliers(prev_shape_size, 1);
+  std::vector<PrimExpr> dividers(new_shape_size, 1);
+
+  auto neg_shape_val = compute(Array<PrimExpr>{1}, [&](const Array<Var>& indices) {
+    tvm::PrimExpr total_ele = prev_shape[0];
+    for (int i = prev_shape_size - 2; i >= 0; --i) {
+      multipliers[i] = prev_shape[i + 1] * multipliers[i + 1];
+      total_ele *= prev_shape[i + 1];
+    }
+
+    PrimExpr division_total_ele = 1;
+    for (int i = 0; i < new_shape_size; ++i) {
+      division_total_ele *= if_then_else(new_shape[i] != -1, new_shape[i], 1);
+    }
+
+    for (int i = new_shape_size - 2; i >= 0; --i) {
+      dividers[i] = dividers[i + 1] * if_then_else(new_shape[i + 1] != -1, new_shape[i + 1],
+                                                   div(total_ele, division_total_ele));
+    }
+    return div(total_ele, division_total_ele);
+  });
+
+  result.push_back(compute(
+      new_sparse_indices_shape,
+      [&](const Array<Var>& indices) {
+        PrimExpr flattened_idx = 0;
+        if (sparse_indices->shape.size() == 1) {
+          flattened_idx += sparse_indices[indices[0]];
+        } else {
+          for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
+            flattened_idx += (sparse_indices[indices[0]][k] * multipliers[k]);
+          }
+        }
+        Array<PrimExpr> new_sparse_indices;
+        if (new_shape_size != 1) {
+          for (int i = 0; i < new_shape_size; i++) {
+            new_sparse_indices.push_back(floordiv(flattened_idx, dividers[i]));
+            flattened_idx = floormod(flattened_idx, dividers[i]);
+          }
+          PrimExpr ret = -1;
+          for (int i = 0; i < new_shape_size; i++) {
+            if (indices.size() == 1) {
+              return new_sparse_indices[0];
+            } else {
+              ret = if_then_else(indices[1] == i, new_sparse_indices[i], ret);
+            }
+          }
+          return ret;
+        } else {
+          return flattened_idx;
+        }
+      },
+      name, tag));
+  result.push_back(compute(
+      Array<PrimExpr>{new_shape_size},
+      [&](const Array<Var>& indices) {
+        PrimExpr ret = new_shape(indices);
+        ret = if_then_else(ret == -1, neg_shape_val[0], ret);
+        return ret;
+      },
+      name, tag));
+  return result;
+}
+
 /*!
  * \brief Transform the layout according to \p src_layout and \p dst_layout
  * \param src the source input.
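The TOPI kernel above is, at heart, a flatten/unflatten of row-major indices: each sparse index is collapsed to a linear offset with the strides of prev_shape (the "multipliers"), a single -1 in new_shape is resolved from the total element count, and the offset is re-expanded with the strides of the resolved shape (the "dividers"). A minimal NumPy sketch of that remapping, not part of the patch (the helper name is illustrative; it assumes 2-D sparse_indices and at most one -1 in new_shape):

.. code-block:: python

    import numpy as np

    def sparse_reshape_reference(sparse_indices, prev_shape, new_shape):
        # Resolve a single -1 entry in new_shape from the total element count.
        total = int(np.prod(prev_shape))
        known = int(np.prod([d for d in new_shape if d != -1]))
        resolved = np.array([d if d != -1 else total // known for d in new_shape])

        # Row-major strides: "multipliers" for prev_shape, "dividers" for new_shape.
        prev_strides = np.append(np.cumprod(prev_shape[::-1])[-2::-1], 1).astype(np.int64)
        new_strides = np.append(np.cumprod(resolved[::-1])[-2::-1], 1).astype(np.int64)

        flat = sparse_indices @ prev_strides                      # flatten old indices
        new_indices = (flat[:, None] // new_strides) % resolved   # unflatten into new shape
        return new_indices, resolved

    indices = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]])
    new_idx, shape = sparse_reshape_reference(indices, np.array([2, 3, 6]), np.array([9, -1]))
    print(new_idx.tolist(), shape.tolist())
    # [[0, 0], [0, 1], [1, 2], [4, 2], [8, 1]] [9, 4]

This is the same result the docstring example and the unit tests below expect; the TE compute expresses the per-element floordiv/floormod directly instead of materialising the flat offsets.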
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index d5746a38582c..d3f5489bd33a 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -975,6 +975,22 @@ def _impl(inputs, attr, params, mod):
     return _impl
 
 
+def _sparse_reshape():
+    def _impl(inputs, attr, params, mod):
+        assert len(inputs) == 3, "There should be 3 input tensors"
+
+        indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
+        values_tensor = params["SparseTensor/values"].asnumpy()
+        prev_shape_tensor = _infer_value(inputs[1], params, mod).asnumpy()
+        new_shape = inputs[2]
+        indices_data = _expr.const(indices_tensor, indices_tensor.dtype)
+        prev_shape_data = _expr.const(prev_shape_tensor, prev_shape_tensor.dtype)
+        ret = _op.sparse_reshape(indices_data, prev_shape_data, new_shape).astuple()
+        return ret, _expr.const(values_tensor, values_tensor.dtype)
+
+    return _impl
+
+
 def _identity():
     def _impl(inputs, attr, params, mod):
         return inputs[0]
@@ -2423,6 +2439,7 @@ def _impl(inputs, attr, params, mod):
     "SpaceToDepth": _space_to_depth(),
     "SparseToDense": _sparse_to_dense(),
     "SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
+    "SparseReshape": _sparse_reshape(),
     "Split": _split(False),
     "SplitV": _split(True),
     "Sqrt": AttrCvt("sqrt"),
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 05ca6d2e4bb9..b54c6b3feb6e 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -63,6 +63,8 @@
 _reg.register_injective_schedule("sparse_to_dense")
 _reg.register_injective_schedule("matrix_set_diag")
 _reg.register_injective_schedule("adv_index")
+_reg.register_injective_schedule("sparse_reshape")
+
 
 # concatenate
 _reg.register_schedule("concatenate", strategy.schedule_concatenate)
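For reference, tf.sparse.reshape only recomputes the indices and dense_shape and leaves the values untouched, which is why the converter above lowers just indices and shape to sparse_reshape and forwards the values tensor as a constant. A small standalone snippet showing that behaviour, not part of the patch (it assumes TF 2.x eager execution; the frontend test below exercises the same op through graph mode):

.. code-block:: python

    import tensorflow as tf

    sp = tf.sparse.SparseTensor(
        indices=[[0, 0, 0], [1, 2, 3]], values=[7, 9], dense_shape=[2, 3, 6]
    )
    out = tf.sparse.reshape(sp, [9, -1])

    print(out.indices.numpy())      # [[0 0] [8 1]]  -- remapped
    print(out.values.numpy())       # [7 9]          -- unchanged
    print(out.dense_shape.numpy())  # [9 4]          -- the -1 is resolved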
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 7e7f9b299593..bf13951c2af0 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1320,3 +1320,48 @@ def adv_index(inputs):
         Output tensor.
     """
     return _make.adv_index(Tuple(inputs))
+
+
+def sparse_reshape(sparse_indices, prev_shape, new_shape):
+    """
+    Reshape a Sparse Tensor
+
+    Parameters
+    ----------
+    sparse_indices : relay.Expr
+        A 2-D tensor[N, n_dim] of integers containing the locations of sparse values, where N
+        is the number of sparse values and n_dim is the number of dimensions of the dense_shape
+    prev_shape : relay.Expr
+        A 1-D tensor containing the previous shape of the dense tensor
+    new_shape : relay.Expr
+        A 1-D tensor containing the new shape of the dense tensor
+
+    Returns
+    -------
+    result : relay.Expr
+        Output tensor.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        sparse_indices = [[0, 0, 0],
+                          [0, 0, 1],
+                          [0, 1, 0],
+                          [1, 0, 0],
+                          [1, 2, 3]]
+
+        prev_shape = [2, 3, 6]
+
+        new_shape = [9, -1]
+
+        new_sparse_indices, new_shape = relay.sparse_reshape(sparse_indices,
+                                                             prev_shape,
+                                                             new_shape)
+
+        new_sparse_indices = [[0, 0],
+                              [0, 1],
+                              [1, 2],
+                              [4, 2],
+                              [8, 1]]
+        new_shape = [9, 4]
+    """
+    return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, new_shape), 2)
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index 6ddbc73e4666..9f0daff079af 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -931,3 +931,48 @@ def adv_index(data, indices):
         Output tensor
     """
     return cpp.adv_index(data, indices)
+
+
+def sparse_reshape(sparse_indices, prev_shape, new_shape):
+    """
+    Reshape a Sparse Tensor
+
+    Parameters
+    ----------
+    sparse_indices : tvm.te.Tensor
+        A 2-D tensor[N, n_dim] of integers containing the locations of sparse values, where N
+        is the number of sparse values and n_dim is the number of dimensions of the dense_shape
+    prev_shape : tvm.te.Tensor
+        A 1-D tensor containing the previous shape of the dense tensor
+    new_shape : tvm.te.Tensor
+        A 1-D tensor containing the new shape of the dense tensor
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        Output tensor.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        sparse_indices = [[0, 0, 0],
+                          [0, 0, 1],
+                          [0, 1, 0],
+                          [1, 0, 0],
+                          [1, 2, 3]]
+
+        prev_shape = [2, 3, 6]
+
+        new_shape = [9, -1]
+
+        new_sparse_indices, new_shape = topi.sparse_reshape(sparse_indices,
+                                                            prev_shape,
+                                                            new_shape)
+
+        new_sparse_indices = [[0, 0],
+                              [0, 1],
+                              [1, 2],
+                              [4, 2],
+                              [8, 1]]
+        new_shape = [9, 4]
+    """
+    return cpp.sparse_reshape(sparse_indices, prev_shape, new_shape)
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 640943eac805..2d43805fa4e5 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1553,6 +1553,47 @@ RELAY_REGISTER_OP("meshgrid")
     .set_attr<FTVMCompute>("FTVMCompute", MeshgridCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+bool SparseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                      const TypeReporter& reporter) {
+  // types: [sparse_indices, prev_shape, new_shape, result]
+  ICHECK_EQ(types.size(), 4) << "SparseReshapeRel expects 4 types but " << types.size()
+                             << " provided";
+  auto sparse_indices = types[0].as<TensorTypeNode>();
+  auto new_shape = types[2].as<TensorTypeNode>();
+  Array<IndexExpr> new_sparse_indices_shape{sparse_indices->shape[0], new_shape->shape[0]};
+  std::vector<Type> fields;
+  fields.push_back(TensorType(new_sparse_indices_shape, sparse_indices->dtype));
+  fields.push_back(TensorType(new_shape->shape, new_shape->dtype));
+  reporter->Assign(types[3], TupleType(Array<Type>(fields)));
+  return true;
+}
+
+Array<te::Tensor> SparseReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                       const Type& out_type) {
+  ICHECK_EQ(inputs.size(), 3) << "SparseReshapeCompute expects 3 inputs but provided "
+                              << inputs.size();
+  return {topi::SparseReshape(inputs[0], inputs[1], inputs[2])};
+}
+
+Expr MakeSparseReshape(Expr sparse_indices, Expr prev_shape, Expr new_shape) {
+  static const Op& op = Op::Get("sparse_reshape");
+  return Call(op, {sparse_indices, prev_shape, new_shape}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.sparse_reshape").set_body_typed(MakeSparseReshape);
+
+RELAY_REGISTER_OP("sparse_reshape")
+    .describe(R"code(Return new sparse indices of the reshaped tensor
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .add_argument("sparse_indices", "Tensor", "The first tensor")
.add_argument("prev_shape", "Tensor", "The second tensor") + .add_argument("new_shape", "Tensor", "The third tensor") + .add_type_rel("sparse_reshape", SparseReshapeRel) + .set_attr("TOpPattern", kInjective) + .set_support_level(3) + .set_attr("FTVMCompute", SparseReshapeCompute); + // tile operator TVM_REGISTER_NODE_TYPE(TileAttrs); diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 22ed6c5b2edf..bf47d51b0c0e 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1811,6 +1811,76 @@ def test_forward_sparse_dense_matmul(): ) +####################################################################### +# SparseReshape +# ------------ + + +def _test_sparse_reshape(indices_np, values_np, prev_shape_np, new_shape_np, dtype): + with tf.Graph().as_default(): + sp_input = tf.sparse.SparseTensor( + indices=indices_np, values=values_np, dense_shape=prev_shape_np + ) + new_shape = tf.placeholder( + shape=new_shape_np.shape, dtype=new_shape_np.dtype, name="new_shape" + ) + + tf.sparse.reshape(sp_input, new_shape, name="sparse_reshape") + compare_tf_with_tvm( + [new_shape_np], + [new_shape.name], + ["sparse_reshape:0", "sparse_reshape:1", "sparse_reshape/Identity:0"], + ) + + +def test_forward_sparse_reshape(): + """ sparse_reshape op test""" + ################################################################### + # + # In order to create a SparseTensor, it requires 3 input as below: + # SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) + # + # Above Sparse can be represented in Dense as below : + # [[1, 0, 0, 0] + # [0, 0, 2, 0] + # [0, 0, 0, 0]] + # + # ------------------------------------------------------------------ + sparse_indices_np = np.array( + [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int32 + ) + sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32) + prev_shape_np = np.array([2, 3, 6], dtype=np.int32) + new_shape_np = np.array([9, 4], dtype=np.int32) + _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32") + + sparse_indices_np = np.array( + [[0, 0, 0, 0], [0, 0, 1, 2], [0, 1, 0, 3], [1, 0, 0, 4], [1, 2, 3, 6]], dtype=np.int32 + ) + sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32) + prev_shape_np = np.array([2, 3, 6, 7], dtype=np.int32) + new_shape_np = np.array([9, -1, 7], dtype=np.int32) + _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32") + + sparse_indices_np = np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int32) + sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32) + prev_shape_np = np.array([9, 4], dtype=np.int32) + new_shape_np = np.array([2, -1, 6], dtype=np.int32) + _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32") + + sparse_indices_np = np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int32) + sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32) + prev_shape_np = np.array([9, 4], dtype=np.int32) + new_shape_np = np.array([-1], dtype=np.int32) + _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, "int32") + + sparse_indices_np = np.array([[0], [5], [10], [20], [24]], dtype=np.int32) + sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32) + prev_shape_np = np.array([25], dtype=np.int32) + new_shape_np = np.array([5, 5], dtype=np.int32) + _test_sparse_reshape(sparse_indices_np, 
+
+
 #######################################################################
 # StridedSlice
 # ------------
 
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 668285dfb882..8bce014e8bf6 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1042,6 +1042,121 @@ def verify_scatter_add(dshape, ishape, axis=0, dtype="float32"):
     verify_scatter_add((16, 16, 4, 5), (16, 16, 4, 5), 3)
 
 
+@tvm.testing.uses_gpu
+def test_sparse_reshape():
+    def ref_sparse_reshape(
+        sparse_indices: np.ndarray,
+        prev_shape: np.ndarray,
+        new_shape: np.ndarray,
+    ):
+        """
+        This function calculates the expected output of sparse_reshape given the inputs.
+        """
+        new_sparse_indices = np.ones(
+            (sparse_indices.shape[0], new_shape.shape[0]), dtype=sparse_indices.dtype
+        )
+        multipliers = np.ones(prev_shape.shape[0])
+        dividers = np.ones(new_shape.shape[0])
+        total_ele = np.prod(prev_shape)
+        division_total_ele = 1
+        for i in range(new_shape.shape[0]):
+            if new_shape[i] == -1:
+                continue
+            division_total_ele *= new_shape[i]
+
+        for i in range(prev_shape.shape[0] - 2, -1, -1):
+            multipliers[i] = prev_shape[i + 1] * multipliers[i + 1]
+
+        for i in range(len(new_shape)):
+            if new_shape[i] == -1:
+                new_shape[i] = total_ele // division_total_ele
+
+        for i in range(new_shape.shape[0] - 2, -1, -1):
+            dividers[i] = new_shape[i + 1] * dividers[i + 1]
+
+        for row_num, sparse_row in enumerate(sparse_indices):
+            flat_idx = 0
+            if len(sparse_indices.shape) != 1:
+                for i, ele in enumerate(sparse_row):
+                    flat_idx += sparse_row[i] * multipliers[i]
+            else:
+                flat_idx += sparse_row
+            if len(new_sparse_indices.shape) != 1:
+                for i in range(new_sparse_indices.shape[1]):
+                    new_sparse_indices[row_num][i] = flat_idx // dividers[i]
+                    flat_idx = flat_idx % dividers[i]
+            else:
+                new_sparse_indices[row_num] = flat_idx
+
+        return new_sparse_indices, new_shape
+
+    def verify_sparse_reshape(
+        sparse_indices_np: np.ndarray,
+        sparse_values_np: np.ndarray,
+        prev_shape_np: np.ndarray,
+        new_shape_np: np.ndarray,
+    ):
+        """
+        This function verifies the relay output of sparse_reshape against its expected output.
+ """ + sparse_indices = relay.var( + "sparse_indices", + relay.TensorType(sparse_indices_np.shape, str(sparse_indices_np.dtype)), + ) + prev_shape = relay.var( + "prev_shape", relay.TensorType(prev_shape_np.shape, str(prev_shape_np.dtype)) + ) + new_shape = relay.var( + "new_shape", relay.TensorType(new_shape_np.shape, str(new_shape_np.dtype)) + ) + z = relay.op.sparse_reshape(sparse_indices, prev_shape, new_shape).astuple() + + func = relay.Function([sparse_indices, prev_shape, new_shape], z) + + ref_res = ref_sparse_reshape(sparse_indices_np, prev_shape_np, new_shape_np) + for target, ctx in tvm.testing.enabled_targets(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(sparse_indices_np, prev_shape_np, new_shape_np) + for op_res_item, ref_res_item in zip(op_res, ref_res): + tvm.testing.assert_allclose( + op_res_item.asnumpy(), ref_res_item, rtol=1e-5, atol=1e-5 + ) + + sparse_indices_np = np.array( + [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], dtype=np.int32 + ) + sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32) + prev_shape_np = np.array([2, 3, 6], dtype=np.int32) + new_shape_np = np.array([9, 4], dtype=np.int32) + verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np) + + sparse_indices_np = np.array( + [[0, 0, 0, 0], [0, 0, 1, 2], [0, 1, 0, 3], [1, 0, 0, 4], [1, 2, 3, 6]], dtype=np.int32 + ) + sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32) + prev_shape_np = np.array([2, 3, 6, 7], dtype=np.int32) + new_shape_np = np.array([9, -1, 7], dtype=np.int32) + verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np) + + sparse_indices_np = np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int32) + sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32) + prev_shape_np = np.array([9, 4], dtype=np.int32) + new_shape_np = np.array([2, -1, 6], dtype=np.int32) + verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np) + + sparse_indices_np = np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int32) + sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32) + prev_shape_np = np.array([9, 4], dtype=np.int32) + new_shape_np = np.array([-1], dtype=np.int32) + verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np) + + sparse_indices_np = np.array([[0], [5], [10], [20], [24]], dtype=np.int32) + sparse_values_np = np.array([7, 5, 6, 3, 9], dtype=np.int32) + prev_shape_np = np.array([25], dtype=np.int32) + new_shape_np = np.array([5, 5], dtype=np.int32) + verify_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np) + + @tvm.testing.uses_gpu def test_gather(): def verify_gather(data, axis, indices, ref_res): @@ -1313,6 +1428,7 @@ def verify_adv_index(data_shape, index_shapes): if __name__ == "__main__": + test_sparse_reshape() test_cast() test_zeros_ones() test_unary_identity()