46 changes: 46 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -903,6 +903,51 @@ def _impl(inputs, attr, params, mod):
    return _impl


def _sparse_tensor_dense_matmul():
    # Sparse utility from scipy
    from scipy.sparse import csr_matrix

    def _impl(inputs, attr, params, mod):
        assert len(inputs) == 4, "There should be 4 input tensors"

        indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
        values_tensor = _infer_value(inputs[1], params, mod).asnumpy()
        dense_shape_tensor = _infer_value(inputs[2], params, mod).asnumpy()

        data = inputs[3]

        rows = [x[0] for x in indices_tensor]
        cols = [x[1] for x in indices_tensor]

        # Create a scipy sparse matrix in CSR format
        weight_sp = csr_matrix(
            (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
Contributor: If you swap rows and columns here you can avoid the sparse transpose below. This probably isn't much of a performance hit except for large matrices.

Contributor Author: I totally agree with your comment, but as this is just an initial PR, I feel we should keep the code as it is, since it reads better in terms of the steps involved. Later on, once all features are merged, we can work together to optimize it 🙂
Please let me know in case you think otherwise!

        )
        weight_sp = csr_matrix(weight_sp.transpose())

        weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
        weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype)
        weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype)

        ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs])

        # Both adjoint flags set means the original first input was dense and the
        # second sparse, so the result of sparse_dense must be transposed.
        # TODO(ANSHUMAN87): Support the other adjoint combinations too
        if attr.get("adjoint_a") and attr.get("adjoint_b"):
            ret = _op.transpose(ret)
        else:
            raise tvm.error.OpAttributeUnImplemented(
                "Only tf.sparse.sparse_dense_matmul() with adjoint_a=True and adjoint_b=True"
                " is supported, but adjoint_a={} and adjoint_b={} were supplied.".format(
                    attr.get("adjoint_a"), attr.get("adjoint_b")
                )
            )

        return ret

    return _impl
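Note: a minimal sketch of the reviewer's suggestion above, not part of this patch. Building the CSR matrix with the row and column arrays swapped yields the transposed weight directly, so the csr_matrix(weight_sp.transpose()) round trip can be dropped. The example triplets mirror the first test case below.

import numpy as np
from scipy.sparse import csr_matrix

indices_tensor = np.array([[0, 0], [1, 2]])
values_tensor = np.array([4.0, 8.0], dtype="float32")
dense_shape_tensor = np.array([3, 5])

rows = [x[0] for x in indices_tensor]
cols = [x[1] for x in indices_tensor]

# Swapping (rows, cols) -> (cols, rows) and reversing the shape builds A^T directly.
weight_sp_direct = csr_matrix(
    (values_tensor, (cols, rows)),
    shape=(int(dense_shape_tensor[1]), int(dense_shape_tensor[0])),
)

# The two-step construction used in the converter above, for comparison.
weight_sp_two_step = csr_matrix(
    csr_matrix(
        (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
    ).transpose()
)
assert (weight_sp_direct != weight_sp_two_step).nnz == 0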


def _identity():
    def _impl(inputs, attr, params, mod):
        return inputs[0]
@@ -2407,6 +2452,7 @@ def _impl(inputs, attr, params, mod):
"SpaceToBatchND": _space_to_batch_nd(),
"SpaceToDepth": _space_to_depth(),
"SparseToDense": _sparse_to_dense(),
"SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
"Split": _split(False),
"SplitV": _split(True),
"Sqrt": AttrCvt("sqrt"),
6 changes: 4 additions & 2 deletions python/tvm/relay/op/nn/nn.py
@@ -2046,7 +2046,7 @@ def sparse_transpose(x):

    Parameters
    ----------
-    x : namedtuple.
+    x : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
        The sparse weight matrix for the fast matrix transpose.

Returns
@@ -2055,7 +2055,9 @@ def sparse_transpose(x):
        Tuple of output sparse tensor (same shape and format as input),
        i.e. if CSR then output is in ([data, indices, indptr]) form
    """
-    return expr.TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3)
+    if hasattr(x, "indices"):
+        return expr.TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3)
+    return expr.TupleWrapper(_make.sparse_transpose(x[0], x[1], x[2]), 3)
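Note: a hedged usage sketch of the two calling conventions the patched sparse_transpose accepts; the identity weights and the CSRWeight name here are illustrative, not from the PR.

import collections
import numpy as np
import scipy.sparse as sp
from tvm import relay

CSRWeight = collections.namedtuple("CSRWeight", ["data", "indices", "indptr"])

sp_mat = sp.csr_matrix(np.eye(4, dtype="float32"))
consts = [relay.const(sp_mat.data), relay.const(sp_mat.indices), relay.const(sp_mat.indptr)]

# 1) namedtuple input, dispatched through the hasattr(x, "indices") branch.
out_nt = relay.nn.sparse_transpose(CSRWeight(*consts))

# 2) plain list/tuple in (data, indices, indptr) order, the new fallback branch.
out_list = relay.nn.sparse_transpose(consts)

# Either way the result is a 3-element TupleWrapper holding the transposed CSR.
t_data, t_indices, t_indptr = out_list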


def contrib_conv2d_winograd_without_weight_transform(
13 changes: 9 additions & 4 deletions python/tvm/topi/cuda/sparse.py
@@ -180,7 +180,7 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
        assert (
            mb >= mi
        ), "Number of block rows in dense matrix must be larger than warp size: {} vs {}.".format(
-            warp_size, m
+            warp_size, mb
        )
        mo = ceil_div(mb, mi)
        ni = 1  # TODO(tkonolige): how do I compute the number of warps per block?
@@ -367,9 +367,14 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
        and isinstance(inputs[2], relay.Constant)
        and isinstance(inputs[3], relay.Constant)
    ):
-        sparse_matrix = sp.bsr_matrix(
-            (inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy())
-        )
+        if len(inputs[1].data.asnumpy().shape) == 1:
+            sparse_matrix = sp.csr_matrix(
+                (inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy())
+            ).tobsr()
+        else:
+            sparse_matrix = sp.bsr_matrix(
+                (inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy())
+            )
        warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
        sparse_matrix = pad_sparse_matrix(sparse_matrix, warp_size)
        return relay.nn._make.sparse_dense_padded(
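Note: a small scipy-only sketch of why the new branch can detect the weight format from dimensionality. scipy stores CSR data as a 1-D array while BSR data is 3-D, and .tobsr() lifts a CSR matrix into the block layout the padded sparse_dense schedule expects:

import numpy as np
import scipy.sparse as sp

csr = sp.csr_matrix(np.eye(4, dtype="float32"))
print(csr.data.ndim)  # 1 -> matches the len(shape) == 1 check above

bsr = csr.tobsr()
print(bsr.data.ndim)  # 3 -> BSR data is laid out as (num_blocks, block_rows, block_cols)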
56 changes: 56 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -1750,6 +1750,62 @@ def test_forward_batch_matmul():
    _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True)


#######################################################################
# SparseTensorDenseMatMul
# ----------------------------------


def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=False):
    """ One iteration of sparse_dense_matmul """

    # TODO(ANSHUMAN87): Support adjoint options too
    for adjoint_a in [False]:
        for adjoint_b in [False]:
            with tf.Graph().as_default():
                A_sp = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=A_shape)
                B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")

                if flip:
                    result = tf.sparse.sparse_dense_matmul(
                        B, A_sp, adjoint_a=adjoint_a, adjoint_b=adjoint_b
                    )
                else:
                    result = tf.sparse.sparse_dense_matmul(
                        A_sp, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b
                    )

                B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)

                # TODO(ANSHUMAN87): There is an issue in CUDA scheduling for CSR; work in progress
                compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True)


def test_forward_sparse_dense_matmul():
    """ sparse_dense_matmul op test"""
    ###################################################################
    #
    # Creating a SparseTensor requires 3 inputs, as below:
    #   SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    #
    # The above sparse tensor is represented in dense form as:
    #   [[1, 0, 0, 0]
    #    [0, 0, 2, 0]
    #    [0, 0, 0, 0]]
    #
    # ------------------------------------------------------------------

    # TODO(ANSHUMAN87): The False case for flip needs to be supported
    # _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 4], [4, 3], "float32")
    _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 5], [4, 3], "float32", True)
    _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 3], [3, 3], "float32", True)
    _test_sparse_dense_matmul(
        [[0, 0], [1, 3], [4, 3]], [3.0, 6.0, 9.0], [5, 5], [5, 5], "float32", True
    )
    _test_sparse_dense_matmul(
        [[0, 0], [1, 3], [4, 3]], [3.0, 6.0, 9.0], [9, 5], [7, 9], "float32", True
    )
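Note: a quick numpy cross-check of the first case above (flip=True, so the graph computes B @ A); the dense form of the SparseTensor is rebuilt by hand:

import numpy as np

# Dense form of SparseTensor(indices=[[0, 0], [1, 2]], values=[4.0, 8.0], dense_shape=[3, 5])
A = np.zeros((3, 5), dtype="float32")
A[0, 0] = 4.0
A[1, 2] = 8.0

B = np.random.uniform(high=5.0, size=(4, 3)).astype("float32")
expected = B @ A  # (4, 3) @ (3, 5) -> (4, 5), the product both TF and TVM should return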


#######################################################################
# StridedSlice
# ------------