Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,15 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {

/*! \brief Attributes for sparse_dense operator */
struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {}
bool sparse_lhs;

TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {
TVM_ATTR_FIELD(sparse_lhs)
.set_default(false)
.describe(
"Indicate whether sparse matrix is multiplied on the right or the left. If true, then "
"the operation is S * D^T (D dense, S sparse). If false, the operation is D * S^T");
}
};

/*! \brief Attributes for sparse_transpose operator */
Expand Down
29 changes: 23 additions & 6 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,28 +916,45 @@ def _impl(inputs, attr, params, mod):

data = inputs[3]

# By default, in TensorFlow the first input, i.e., data, is sparse
sparse_lhs = True

# If both are true, the first input was dense and the second was sparse
if attr.get("adjoint_a") and attr.get("adjoint_b"):
sparse_lhs = False

rows = [x[0] for x in indices_tensor]
cols = [x[1] for x in indices_tensor]

# Create scipy sparse Tensor(CSR)
weight_sp = csr_matrix(
(values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
)
weight_sp = csr_matrix(weight_sp.transpose())

if sparse_lhs:
data = _op.transpose(data)
else:
weight_sp = csr_matrix(weight_sp.transpose())

weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype)
weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype)

ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs])
ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs], sparse_lhs)

# If both are true means First input was dense and second was sparse
# TODO(ANSHUMAN87): Support other adjoint option too
if attr.get("adjoint_a") and attr.get("adjoint_b"):
if not sparse_lhs:
ret = _op.transpose(ret)
else:

# Case 1. If both are true means first input was dense and second was sparse
# Case 2. If both are false means first input was sparse and second was dense
# TODO(ANSHUMAN87): Support other adjoint option too
if not (
(attr.get("adjoint_a") and attr.get("adjoint_b"))
or ((not attr.get("adjoint_a")) and (not attr.get("adjoint_b")))
):
raise tvm.error.OpAttributeUnImplemented(
"Only tf.sparse.sparse_dense_matmul() with adjoint_a=True and adjoint_b=True"
"or with adjoint_a=False and adjoint_b=False"
" is supported, but adjoint_a={} and adjoint_b={} was supplied.".format(
attr.get("adjoint_a"), attr.get("adjoint_b")
)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def compute_fifo_buffer(attrs, inputs, out_type):
@reg.register_compute("nn.sparse_dense")
def compute_sparse_dense(attrs, inputs, out_type):
"""Compute definition of sparse_dense"""
return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]
return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3], attrs["sparse_lhs"])]


reg.register_strategy("nn.sparse_dense", strategy.sparse_dense_strategy)
Expand Down
44 changes: 31 additions & 13 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1993,17 +1993,27 @@ def batch_matmul(x, y):
return _make.batch_matmul(x, y)


def sparse_dense(data, weight):
# pylint: disable=no-else-return,inconsistent-return-statements
def sparse_dense(dense_mat, sparse_mat, sparse_lhs=False):
r"""
Computes the matrix multiplication of `data` and `weight`, where `data` is
a dense matrix and `weight` is a sparse (either BSR or CSR) namedtuple with
Computes the matrix multiplication of `dense_mat` and `sparse_mat`, where `dense_mat` is
a dense matrix and `sparse_mat` is a sparse (either BSR or CSR) namedtuple with
fields `data`, `indices`, and `indptr`.

.. math::
\if sparse_lhs=False:
.. math::

\mbox{sparse_dense}(dense_mat, sparse_mat)[m, n]
= \mbox{matmul}(D, \mbox{as_dense}(S)^T)[m, n]

\mbox{sparse_dense}(data, weight)[m, n] = \mbox{matmul}(x, \mbox{as_dense}(weight)^T)[m, n]
\if sparse_lhs=True:
.. math::

where `as_dense` returns dense equivalent of the given sparse matrix.
\mbox{sparse_dense}(dense_mat, sparse_mat)[m, n]
= \mbox{matmul}(\mbox{as_dense}(S), (D)^T)[m, n]

where `as_dense` returns dense equivalent of the given S(sparse matrix)
while performing matmul with given D(dense matrix).

See
https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
Expand All @@ -2013,20 +2023,28 @@ def sparse_dense(data, weight):

Parameters
----------
data : tvm.relay.Expr
The input data for the matrix multiplication
dense_mat : tvm.relay.Expr
The input dense matrix for the matrix multiplication

weight : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
The sparse weight matrix for the matrix multiplication.
sparse_mat : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
The input sparse matrix for the matrix multiplication.

sparse_lhs : bool, optional
Indicates whether lhs or rhs matrix is sparse. Default value is False.

Returns
-------
result: tvm.relay.Expr
The computed result.
"""
if hasattr(weight, "indices"):
return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
return _make.sparse_dense(data, weight[0], weight[1], weight[2])
if hasattr(sparse_mat, "indices"):
return _make.sparse_dense(
dense_mat, sparse_mat.data, sparse_mat.indices, sparse_mat.indptr, sparse_lhs
)
else:
return _make.sparse_dense(
dense_mat, sparse_mat[0], sparse_mat[1], sparse_mat[2], sparse_lhs
)


def sparse_transpose(x):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ def wrap_compute_sparse_dense(topi_compute):
"""wrap sparse dense topi compute"""

def _compute_sparse_dense(attrs, inputs, out_type):
return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])]
return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3], attrs["sparse_lhs"])]

return _compute_sparse_dense

Expand Down
6 changes: 4 additions & 2 deletions python/tvm/topi/cuda/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ def schedule_sparse_dense(outs):
# pylint:disable=invalid-name
s = te.create_schedule([x.op for x in outs])

# TODO(ANSHUMAN87): Add for sparse_dense_bsrmm_v1 also
def _callback(op):
if op.tag == "sparse_dense_bsrmm":
if op.tag == "sparse_dense_bsrmm_v2":
y_bsrmm = op.input_tensors[0]
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v2"
out = s.outputs[0].output(0)

if op not in s.outputs:
Expand Down Expand Up @@ -362,6 +363,7 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
sparse_dense implementation for one that operates on a padded matrix. We
also padd the matrix.
"""
# TODO(ANSHUMAN87): Handle for sparse_lhs case too
if (
isinstance(inputs[1], relay.Constant)
and isinstance(inputs[2], relay.Constant)
Expand Down
140 changes: 132 additions & 8 deletions python/tvm/topi/nn/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..utils import get_const_tuple


def sparse_dense(data, weight_data, weight_indices, weight_indptr):
def sparse_dense_v2(data, weight_data, weight_indices, weight_indptr):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`
Expand Down Expand Up @@ -52,13 +52,104 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr):
"""
assert len(weight_data.shape) in (1, 3)
if len(weight_data.shape) == 1:
func = _sparse_dense_csrmm
func = _sparse_dense_csrmm_v2
if len(weight_data.shape) == 3:
func = _sparse_dense_bsrmm
func = _sparse_dense_bsrmm_v2
return func(data, weight_data, weight_indices, weight_indptr)


def _sparse_dense_csrmm(data, weight_data, weight_indices, weight_indptr):
def sparse_dense_v1(data_data, data_indices, data_indptr, weight):
    """
    Multiply a sparse left-hand matrix by the transpose of a dense matrix.

    The sparse operand is given by `(data_data, data_indices, data_indptr)`
    and the dense operand by `weight`; the result is the product of the
    sparse matrix and `weight.T`.

    Parameters
    ----------
    data_data:
        1-D with shape [nnz] (CSR) or
        3-D with shape [num_blocks, bs_r, bs_c] (BSR)

    data_indices:
        1-D with shape [nnz] (CSR) or
        1-D with shape [num_blocks] (BSR)

    data_indptr:
        1-D with shape [M + 1] (CSR) or
        1-D with shape [(M + 1) // bs_r] (BSR)

    weight:
        2-D with shape [N, K], float32

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [M, N]
    """
    # The rank of the value buffer tells the two storage formats apart:
    # rank 1 is a flat CSR value array, rank 3 is a stack of BSR blocks.
    rank = len(data_data.shape)
    assert rank in (1, 3)
    if rank == 1:
        kernel = _sparse_dense_csrmm_v1
    elif rank == 3:
        kernel = _sparse_dense_bsrmm_v1
    return kernel(data_data, data_indices, data_indptr, weight)


# pylint: disable=no-else-return,inconsistent-return-statements
def sparse_dense(dense_data, sparse_data, sparse_indices, sparse_indptr, sparse_lhs=False):
    """
    Sparse-dense matrix multiplication with the sparse operand on either side.

    If sparse_lhs=False, computes the product of `dense_data` and
    `(sparse_data, sparse_indices, sparse_indptr).T`.
    If sparse_lhs=True, computes the product of
    `(sparse_data, sparse_indices, sparse_indptr)` and `dense_data.T`.

    The shapes below are written for the sparse_lhs=False case.

    Parameters
    ----------
    dense_data : tvm.te.Tensor
        2-D with shape [M, K], float32

    sparse_data : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
        3-D with shape [num_blocks, bs_r, bs_c] (BSR)

    sparse_indices : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
        1-D with shape [num_blocks] (BSR)

    sparse_indptr : tvm.te.Tensor
        1-D with shape [N + 1] (CSR) or
        1-D with shape [(N + 1) // bs_r] (BSR)

    sparse_lhs : bool, optional
        Indicates whether lhs or rhs matrix is sparse. Default value is False.

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [M, N]
    """
    if not sparse_lhs:
        # Dense operand on the left: the dense tensor is the first argument.
        return sparse_dense_v2(dense_data, sparse_data, sparse_indices, sparse_indptr)
    # Sparse operand on the left: the sparse triple leads and the dense tensor comes last.
    return sparse_dense_v1(sparse_data, sparse_indices, sparse_indptr, dense_data)


def _sparse_dense_csrmm_v1(data_data, data_indices, data_indptr, weight):
    """CSR kernel for a sparse left-hand operand times the transpose of dense `weight`.

    The output has one row per CSR row (length of `data_indptr` minus one)
    and one column per row of `weight`.
    """
    num_rows = get_const_tuple(data_indptr.shape)[0] - 1
    num_cols = get_const_tuple(weight.shape)[0]

    # NOTE(review): argument names `row` and `i` are kept as in the original
    # in case te.compute derives axis names from them -- confirm before renaming.
    def _row_dot(row, i):
        first = data_indptr[row]
        last = data_indptr[row + 1]
        # Reduce over the stored entries of this CSR row only.
        elem_idx = te.reduce_axis((0, last - first), name="elem_idx")
        pos = first + elem_idx
        sparse_val = data_data[pos]
        dense_val = weight[i, data_indices[pos]]
        return te.sum(sparse_val * dense_val, axis=elem_idx)

    return te.compute((num_rows, num_cols), _row_dot, tag="sparse_dense_csrmm_v1")


def _sparse_dense_csrmm_v2(data, weight_data, weight_indices, weight_indptr):
oshape = (get_const_tuple(data.shape)[0], get_const_tuple(weight_indptr.shape)[0] - 1)

def f(i, row):
Expand All @@ -71,10 +162,41 @@ def f(i, row):
weight_val = data[i, weight_indices[elem]]
return te.sum(a_val * weight_val, axis=elem_idx)

return te.compute(oshape, f, tag="sparse_dense_csrmm")
return te.compute(oshape, f, tag="sparse_dense_csrmm_v2")


def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
def _sparse_dense_bsrmm_v1(data_data, data_indices, data_indptr, weight):
    """BSR kernel for a sparse left-hand operand times the transpose of dense `weight`.

    Builds an intermediate tensor of shape (num_blocks, bs_r, weight_rows),
    where num_blocks is the length of `data_indptr` minus one, then flattens
    its first two axes into an output of shape (num_blocks * bs_r, weight_rows).
    """
    # Tuple unpacking doubles as a rank check: weight must be 2-D,
    # data_data 3-D and data_indptr 1-D.
    weight_rows, _ = get_const_tuple(weight.shape)
    _, bs_r, bs_c = get_const_tuple(data_data.shape)
    (indptr_len,) = get_const_tuple(data_indptr.shape)
    num_blocks = indptr_len - 1

    # NOTE(review): argument names of both callbacks are kept as in the original
    # in case te.compute derives axis names from them -- confirm before renaming.
    def _compute_block(nb_j, j, i):
        first = data_indptr[nb_j]
        last = data_indptr[nb_j + 1]
        # Reduce over the blocks stored in this block row and the columns inside each block.
        elem_idx = te.reduce_axis((0, last - first), name="elem_idx")
        c = te.reduce_axis((0, bs_c), name="c")
        block_offset = first + elem_idx
        block_col = data_indices[block_offset]
        sparse_val = data_data[block_offset][j][c]
        dense_val = weight[i, bs_c * block_col + c]
        return te.sum(sparse_val * dense_val, axis=[elem_idx, c])

    bsrmm_block = te.compute(
        (num_blocks, bs_r, weight_rows), _compute_block, tag="sparse_dense_bsrmm_block_v1"
    )

    def _flatten(m, n):
        # Map a flat output row back to (block row, row within block).
        return bsrmm_block[tvm.tir.indexdiv(m, bs_r), tvm.tir.indexmod(m, bs_r), n]

    return te.compute(
        (num_blocks * bs_r, weight_rows),
        _flatten,
        tag="sparse_dense_bsrmm_v1",
    )


def _sparse_dense_bsrmm_v2(data, weight_data, weight_indices, weight_indptr):
(m, _) = get_const_tuple(data.shape)
(_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
(num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape)
Expand All @@ -95,11 +217,13 @@ def _compute_block(i, nb_j, j):
idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod

bsrmm_block = te.compute((m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block")
bsrmm_block = te.compute(
(m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block_v2"
)
return te.compute(
(m, num_blocks * bs_r),
lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)],
tag="sparse_dense_bsrmm",
tag="sparse_dense_bsrmm_v2",
)


Expand Down
Loading