diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index b2555de6d35e..3fdbd7205a9d 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -937,7 +937,15 @@ struct DenseAttrs : public tvm::AttrsNode { /*! \brief Attributes for sparse_dense operator */ struct SparseDenseAttrs : public tvm::AttrsNode { - TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {} + bool sparse_lhs; + + TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") { + TVM_ATTR_FIELD(sparse_lhs) + .set_default(false) + .describe( + "Indicate whether sparse matrix is multiplied on the right or the left. If true, then " + "the operation is S * D^T (D dense, S sparse). If false, the operation is D * S^T"); + } }; /*! \brief Attributes for sparse_transpose operator */ diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index c6079b4535c4..aa87c2284697 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -916,6 +916,13 @@ def _impl(inputs, attr, params, mod): data = inputs[3] + # By default, in tensorflow the first input ,i.e., data is sparse + sparse_lhs = True + + # If both are true means First input was dense and second was sparse + if attr.get("adjoint_a") and attr.get("adjoint_b"): + sparse_lhs = False + rows = [x[0] for x in indices_tensor] cols = [x[1] for x in indices_tensor] @@ -923,21 +930,31 @@ def _impl(inputs, attr, params, mod): weight_sp = csr_matrix( (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist()) ) - weight_sp = csr_matrix(weight_sp.transpose()) + + if sparse_lhs: + data = _op.transpose(data) + else: + weight_sp = csr_matrix(weight_sp.transpose()) weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype) weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype) weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype) - ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs]) + ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs], sparse_lhs) - # If both are true means First input was dense and second was sparse - # TODO(ANSHUMAN87): Support other adjoint option too - if attr.get("adjoint_a") and attr.get("adjoint_b"): + if not sparse_lhs: ret = _op.transpose(ret) - else: + + # Case 1. If both are true means first input was dense and second was sparse + # Case 2. If both are false means first input was sparse and second was dense + # TODO(ANSHUMAN87): Support other adjoint option too + if not ( + (attr.get("adjoint_a") and attr.get("adjoint_b")) + or ((not attr.get("adjoint_a")) and (not attr.get("adjoint_b"))) + ): raise tvm.error.OpAttributeUnImplemented( "Only tf.sparse.sparse_dense_matmul() with adjoint_a=True and adjoint_b=True" + "or with adjoint_a=False and adjoint_b=False" " is supported, but adjoint_a={} and adjoint_b={} was supplied.".format( attr.get("adjoint_a"), attr.get("adjoint_b") ) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index c235f87d1e99..93149b5fa1f4 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -69,7 +69,7 @@ def compute_fifo_buffer(attrs, inputs, out_type): @reg.register_compute("nn.sparse_dense") def compute_sparse_dense(attrs, inputs, out_type): """Compute definition of sparse_dense""" - return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])] + return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3], attrs["sparse_lhs"])] reg.register_strategy("nn.sparse_dense", strategy.sparse_dense_strategy) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 4810bdc35bbd..eb6ff45d2f12 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1993,17 +1993,27 @@ def batch_matmul(x, y): return _make.batch_matmul(x, y) -def sparse_dense(data, weight): +# pylint: disable=no-else-return,inconsistent-return-statements +def sparse_dense(dense_mat, sparse_mat, sparse_lhs=False): r""" - Computes the matrix multiplication of `data` and `weight`, where `data` is - a dense matrix and `weight` is a sparse (either BSR or CSR) namedtuple with + Computes the matrix multiplication of `dense_mat` and `sparse_mat`, where `dense_mat` is + a dense matrix and `sparse_mat` is a sparse (either BSR or CSR) namedtuple with fields `data`, `indices`, and `indptr`. - .. math:: + \if sparse_lhs=False: + .. math:: + + \mbox{sparse_dense}(dense_mat, sparse_mat)[m, n] + = \mbox{matmul}(D, \mbox{as_dense}(S)^T)[m, n] - \mbox{sparse_dense}(data, weight)[m, n] = \mbox{matmul}(x, \mbox{as_dense}(weight)^T)[m, n] + \if sparse_lhs=True: + .. math:: - where `as_dense` returns dense equivalent of the given sparse matrix. + \mbox{sparse_dense}(dense_mat, sparse_mat)[m, n] + = \mbox{matmul}(\mbox{as_dense}(S), (D)^T)[m, n] + + where `as_dense` returns dense equivalent of the given S(sparse matrix) + while performing matmul with given D(dense matrix). See https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html @@ -2013,20 +2023,28 @@ def sparse_dense(data, weight): Parameters ---------- - data : tvm.relay.Expr - The input data for the matrix multiplication + dense_mat : tvm.relay.Expr + The input dense matrix for the matrix multiplication - weight : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]]. - The sparse weight matrix for the matrix multiplication. + sparse_mat : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]]. + The input sparse matrix for the matrix multiplication. + + sparse_lhs : bool, optional + Indicates whether lhs or rhs matrix is sparse. Default value is False. Returns ------- result: tvm.relay.Expr The computed result. """ - if hasattr(weight, "indices"): - return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr) - return _make.sparse_dense(data, weight[0], weight[1], weight[2]) + if hasattr(sparse_mat, "indices"): + return _make.sparse_dense( + dense_mat, sparse_mat.data, sparse_mat.indices, sparse_mat.indptr, sparse_lhs + ) + else: + return _make.sparse_dense( + dense_mat, sparse_mat[0], sparse_mat[1], sparse_mat[2], sparse_lhs + ) def sparse_transpose(x): diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index bdefbcb79009..7b2b3e35e077 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -706,7 +706,7 @@ def wrap_compute_sparse_dense(topi_compute): """wrap sparse dense topi compute""" def _compute_sparse_dense(attrs, inputs, out_type): - return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])] + return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3], attrs["sparse_lhs"])] return _compute_sparse_dense diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index ebac5517d46c..c59e6887d47e 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -65,10 +65,11 @@ def schedule_sparse_dense(outs): # pylint:disable=invalid-name s = te.create_schedule([x.op for x in outs]) + # TODO(ANSHUMAN87): Add for sparse_dense_bsrmm_v1 also def _callback(op): - if op.tag == "sparse_dense_bsrmm": + if op.tag == "sparse_dense_bsrmm_v2": y_bsrmm = op.input_tensors[0] - assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block" + assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v2" out = s.outputs[0].output(0) if op not in s.outputs: @@ -362,6 +363,7 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type): sparse_dense implementation for one that operates on a padded matrix. We also padd the matrix. """ + # TODO(ANSHUMAN87): Handle for sparse_lhs case too if ( isinstance(inputs[1], relay.Constant) and isinstance(inputs[2], relay.Constant) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 55b3e6a7d1e5..94d6d9a16330 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -23,7 +23,7 @@ from ..utils import get_const_tuple -def sparse_dense(data, weight_data, weight_indices, weight_indptr): +def sparse_dense_v2(data, weight_data, weight_indices, weight_indptr): """ Computes sparse-dense matrix multiplication of `data` and `(weight_data, weight_indices, weight_indptr).T` @@ -52,13 +52,104 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr): """ assert len(weight_data.shape) in (1, 3) if len(weight_data.shape) == 1: - func = _sparse_dense_csrmm + func = _sparse_dense_csrmm_v2 if len(weight_data.shape) == 3: - func = _sparse_dense_bsrmm + func = _sparse_dense_bsrmm_v2 return func(data, weight_data, weight_indices, weight_indptr) -def _sparse_dense_csrmm(data, weight_data, weight_indices, weight_indptr): +def sparse_dense_v1(data_data, data_indices, data_indptr, weight): + """ + Computes sparse-dense matrix multiplication of + `(data_data, data_indices, data_indptr)` and `weight.T` + + Parameters + ---------- + data_data: + 1-D with shape [nnz] (CSR) or + 3-D with shape [num_blocks, bs_r, bs_c] (BSR) + + data_indices: + 1-D with shape [nnz] (CSR) or + 1-D with shape [num_blocks] (BSR) + + data_indptr: + 1-D with shape [M + 1] (CSR) or + 1-D with shape [(M + 1) // bs_r] (BSR) + + weight: + 2-D with shape [N, K], float32 + + Returns + ------- + output : tvm.te.Tensor + 2-D with shape [M, N] + """ + assert len(data_data.shape) in (1, 3) + if len(data_data.shape) == 1: + func = _sparse_dense_csrmm_v1 + if len(data_data.shape) == 3: + func = _sparse_dense_bsrmm_v1 + return func(data_data, data_indices, data_indptr, weight) + + +# pylint: disable=no-else-return,inconsistent-return-statements +def sparse_dense(dense_data, sparse_data, sparse_indices, sparse_indptr, sparse_lhs=False): + """ + Computes sparse-dense matrix multiplication of `data` and + `(weight_data, weight_indices, weight_indptr).T`, if sparse_lhs=False + or + Computes sparse-dense matrix multiplication of + `(data_data, data_indices, data_indptr)` and `weight.T`, if sparse_lhs=True + + Parameters + ---------- + dense_data : tvm.te.Tensor + 2-D with shape [M, K], float32 + + sparse_data : tvm.te.Tensor + 1-D with shape [nnz] (CSR) or + 3-D with shape [num_blocks, bs_r, bs_c] (BSR) + + sparse_indices : tvm.te.Tensor + 1-D with shape [nnz] (CSR) or + 1-D with shape [num_blocks] (BSR) + + sparse_indptr : tvm.te.Tensor + 1-D with shape [N + 1] (CSR) or + 1-D with shape [(N + 1) // bs_r] (BSR) + + sparse_lhs : bool, optional + Indicates whether lhs or rhs matrix is sparse. Default value is False. + + Returns + ------- + output : tvm.te.Tensor + 2-D with shape [M, N] + """ + if sparse_lhs: + return sparse_dense_v1(sparse_data, sparse_indices, sparse_indptr, dense_data) + else: + return sparse_dense_v2(dense_data, sparse_data, sparse_indices, sparse_indptr) + + +def _sparse_dense_csrmm_v1(data_data, data_indices, data_indptr, weight): + oshape = (get_const_tuple(data_indptr.shape)[0] - 1, get_const_tuple(weight.shape)[0]) + + def f(row, i): + row_start = data_indptr[row] + row_end = data_indptr[row + 1] + row_elems = row_end - row_start + elem_idx = te.reduce_axis((0, row_elems), name="elem_idx") + elem = row_start + elem_idx + a_val = data_data[elem] + weight_val = weight[i, data_indices[elem]] + return te.sum(a_val * weight_val, axis=elem_idx) + + return te.compute(oshape, f, tag="sparse_dense_csrmm_v1") + + +def _sparse_dense_csrmm_v2(data, weight_data, weight_indices, weight_indptr): oshape = (get_const_tuple(data.shape)[0], get_const_tuple(weight_indptr.shape)[0] - 1) def f(i, row): @@ -71,10 +162,41 @@ def f(i, row): weight_val = data[i, weight_indices[elem]] return te.sum(a_val * weight_val, axis=elem_idx) - return te.compute(oshape, f, tag="sparse_dense_csrmm") + return te.compute(oshape, f, tag="sparse_dense_csrmm_v2") -def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr): +def _sparse_dense_bsrmm_v1(data_data, data_indices, data_indptr, weight): + (m, _) = get_const_tuple(weight.shape) + (_, bs_r, bs_c) = get_const_tuple(data_data.shape) + (num_blocks_plus_1,) = get_const_tuple(data_indptr.shape) + num_blocks = num_blocks_plus_1 - 1 + + def _compute_block(nb_j, j, i): + row_start = data_indptr[nb_j] + row_end = data_indptr[nb_j + 1] + row_elems = row_end - row_start + elem_idx = te.reduce_axis((0, row_elems), name="elem_idx") + block_offset = row_start + elem_idx + c = te.reduce_axis((0, bs_c), name="c") + block_j = data_indices[block_offset] + block_ij_val = data_data[block_offset][j][c] + x_val = weight[i, bs_c * block_j + c] + return te.sum(block_ij_val * x_val, axis=[elem_idx, c]) + + idxd = tvm.tir.indexdiv + idxm = tvm.tir.indexmod + + bsrmm_block = te.compute( + (num_blocks, bs_r, m), _compute_block, tag="sparse_dense_bsrmm_block_v1" + ) + return te.compute( + (num_blocks * bs_r, m), + lambda m, n: bsrmm_block[idxd(m, bs_r), idxm(m, bs_r), n], + tag="sparse_dense_bsrmm_v1", + ) + + +def _sparse_dense_bsrmm_v2(data, weight_data, weight_indices, weight_indptr): (m, _) = get_const_tuple(data.shape) (_, bs_r, bs_c) = get_const_tuple(weight_data.shape) (num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape) @@ -95,11 +217,13 @@ def _compute_block(i, nb_j, j): idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod - bsrmm_block = te.compute((m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block") + bsrmm_block = te.compute( + (m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block_v2" + ) return te.compute( (m, num_blocks * bs_r), lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)], - tag="sparse_dense_bsrmm", + tag="sparse_dense_bsrmm_v2", ) diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index 09dca09a82de..e9073730641d 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -39,44 +39,76 @@ TVM_REGISTER_NODE_TYPE(SparseDenseAttrs); bool SparseDenseRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 5); - const auto* data = types[0].as(); - const auto* weight_data = types[1].as(); - ICHECK(weight_data->shape.size() == 1 || weight_data->shape.size() == 3); - const auto* weight_indptr = types[3].as(); - if (data == nullptr) return false; - - if (weight_data->shape.size() == 1) { - // CSR case. - Array oshape({data->shape[0], weight_indptr->shape[0] - 1}); - reporter->Assign(types[4], TensorType(oshape, data->dtype)); - return true; + const auto* param = attrs.as(); + ICHECK(param != nullptr); + + if (param->sparse_lhs) { + const auto* weight = types[0].as(); + const auto* data_data = types[1].as(); + ICHECK(data_data->shape.size() == 1 || data_data->shape.size() == 3); + const auto* data_indptr = types[3].as(); + if (weight == nullptr) return false; + + if (data_data->shape.size() == 1) { + // CSR case. + Array oshape({data_indptr->shape[0] - 1, weight->shape[0]}); + reporter->Assign(types[4], TensorType(oshape, weight->dtype)); + return true; + } + + if (data_data->shape.size() == 3) { + // BSR case. + Array oshape( + {(data_indptr->shape[0] - 1) * data_data->shape[1], weight->shape[0]}); + reporter->Assign(types[4], TensorType(oshape, weight->dtype)); + return true; + } + LOG(FATAL) << "Unknown data ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)"; + return false; + + } else { + const auto* data = types[0].as(); + const auto* weight_data = types[1].as(); + ICHECK(weight_data->shape.size() == 1 || weight_data->shape.size() == 3); + const auto* weight_indptr = types[3].as(); + if (data == nullptr) return false; + + if (weight_data->shape.size() == 1) { + // CSR case. + Array oshape({data->shape[0], weight_indptr->shape[0] - 1}); + reporter->Assign(types[4], TensorType(oshape, data->dtype)); + return true; + } + + if (weight_data->shape.size() == 3) { + // BSR case. + Array oshape( + {data->shape[0], (weight_indptr->shape[0] - 1) * weight_data->shape[1]}); + reporter->Assign(types[4], TensorType(oshape, data->dtype)); + return true; + } + LOG(FATAL) << "Unknown weight ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)"; + return false; } - - if (weight_data->shape.size() == 3) { - // BSR case. - Array oshape( - {data->shape[0], (weight_indptr->shape[0] - 1) * weight_data->shape[1]}); - reporter->Assign(types[4], TensorType(oshape, data->dtype)); - return true; - } - LOG(FATAL) << "Unknown weight ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)"; - return false; } // Positional relay function to create dense operator used by frontend FFI. -Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) { +Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr, + bool sparse_lhs) { auto attrs = make_object(); + attrs->sparse_lhs = std::move(sparse_lhs); static const Op& op = Op::Get("nn.sparse_dense"); return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense") .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeSparseDense, args, rv); + runtime::detail::unpack_call(MakeSparseDense, args, rv); }); RELAY_REGISTER_OP("nn.sparse_dense") - .describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with W sparse. + .describe( + R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with either X or W sparse. - **data**: `(x1, x2, ..., xn, input_dim)` - **weight**: `(units, input_dim)` @@ -85,10 +117,10 @@ RELAY_REGISTER_OP("nn.sparse_dense") )code" TVM_ADD_FILELINE) .set_attrs_type() .set_num_inputs(4) - .add_argument("data", "nD Tensor", "Input data.") - .add_argument("weight_data", "1D Tensor", "Weight data matrix.") - .add_argument("weight_indices", "1D Tensor", "Weight indices matrix.") - .add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.") + .add_argument("dense_data", "nD Tensor", "Input dense data.") + .add_argument("sparse_data", "1D or 3D Tensor", "Sparse data matrix.") + .add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.") + .add_argument("sparse_indptr", "1D Tensor", "Sparse indptr matrix.") .set_support_level(1) .add_type_rel("SparseDense", SparseDenseRel); diff --git a/src/relay/transforms/convert_sparse_dense.cc b/src/relay/transforms/convert_sparse_dense.cc index 5f4dbe642c3d..26a4d487196d 100644 --- a/src/relay/transforms/convert_sparse_dense.cc +++ b/src/relay/transforms/convert_sparse_dense.cc @@ -103,8 +103,10 @@ class DenseToSparseDenseMutator : public ExprRewriter { Var weight_data(prefix + ".data", ws_data_type); Var weight_indices(prefix + ".indices", ws_indices_type); Var weight_indptr(prefix + ".indptr", ws_indptr_type); + auto attrs = make_object(); - return Call(sparse_dense_op_, {data, weight_data, weight_indices, weight_indptr}); + return Call(sparse_dense_op_, {data, weight_data, weight_indices, weight_indptr}, + Attrs(attrs)); } } } diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index b79bd8bbba52..6720c2e13bfe 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1794,9 +1794,11 @@ def test_forward_sparse_dense_matmul(): # # ------------------------------------------------------------------ - # TODO(ANSHUMAN87): False case for flip need to be supported - # _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 4], [4, 3], "float32") - _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 5], [4, 3], "float32", True) + _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 4], [4, 3], "float32") + _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 3], [3, 3], "float32") + _test_sparse_dense_matmul([[0, 0], [1, 3], [4, 3]], [3.0, 6.0, 9.0], [5, 5], [5, 5], "float32") + _test_sparse_dense_matmul([[0, 0], [1, 3], [4, 3]], [3.0, 6.0, 9.0], [7, 9], [9, 5], "float32") + _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [4, 3], [3, 4], "float32", True) _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 3], [3, 3], "float32", True) _test_sparse_dense_matmul( [[0, 0], [1, 3], [4, 3]], [3.0, 6.0, 9.0], [5, 5], [5, 5], "float32", True diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py index 62f49e21418f..e47bfddbf7fc 100644 --- a/tests/python/topi/python/test_topi_sparse.py +++ b/tests/python/topi/python/test_topi_sparse.py @@ -272,6 +272,31 @@ def test_sparse_dense_csr(): tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) +def test_sparse_dense_csr_reverse(): + M, N, K, density = 1, 17, 47, 0.2 + X_np = np.random.randn(M, K).astype("float32") + W_sp_np = sp.random(N, K, density=density, format="csr", dtype="float32") + W_np = W_sp_np.todense() + Y_np = W_np.dot(X_np.T) + + W_data = te.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype)) + W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype)) + W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype)) + X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype)) + Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr, sparse_lhs=True) + s = te.create_schedule(Y.op) + func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype)) + func( + tvm.nd.array(X_np), + tvm.nd.array(W_sp_np.data), + tvm.nd.array(W_sp_np.indices), + tvm.nd.array(W_sp_np.indptr), + Y_tvm, + ) + tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) + + def test_sparse_transpose_csr(): N, density = 1023, 0.3 @@ -368,6 +393,31 @@ def test_sparse_dense_bsr_relu(ctx, target): verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, False, ctx, target) +def test_sparse_dense_bsr_reverse(): + M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9 + X_np = np.random.randn(M, K).astype("float32") + W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32") + W_np = W_sp_np.todense() + Y_np = W_np.dot(X_np.T) + + W_data = te.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype)) + W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype)) + W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype)) + X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype)) + Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr, sparse_lhs=True) + s = te.create_schedule(Y.op) + func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype)) + func( + tvm.nd.array(X_np), + tvm.nd.array(W_sp_np.data), + tvm.nd.array(W_sp_np.indices), + tvm.nd.array(W_sp_np.indptr), + Y_tvm, + ) + tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) + + @tvm.testing.uses_gpu def test_sparse_dense_bsr_randomized(): for _ in range(20): @@ -480,3 +530,5 @@ def test_sparse_dense_padded_alter_op(): test_sparse_transpose_csr() test_sparse_dense_padded_cuda() test_sparse_dense_padded_alter_op() + test_sparse_dense_csr_reverse() + test_sparse_dense_bsr_reverse()