From a6f4f7b88de531c375ff957648c371b60918302e Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Wed, 14 Oct 2020 23:54:20 +0530 Subject: [PATCH 1/4] [Relay][Frontend] SparseTensorDenseMatMul support for Tensorflow --- python/tvm/relay/frontend/tensorflow.py | 35 ++++++++++++++ python/tvm/relay/op/nn/nn.py | 6 ++- python/tvm/topi/cuda/sparse.py | 13 ++++-- .../frontend/tensorflow/test_forward.py | 46 +++++++++++++++++++ 4 files changed, 94 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 2c7adf03bad8..38decfac7eff 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -902,6 +902,40 @@ def _impl(inputs, attr, params, mod): return _impl +def _sparse_tensor_dense_matmul(): + # Sparse utility from Numpy + from scipy import sparse + + def _impl(inputs, attr, params, mod): + assert len(inputs) == 4, "There should be 4 input tensors" + + indices_tensor = _infer_value(inputs[0], params, mod).asnumpy() + values_tensor = _infer_value(inputs[1], params, mod).asnumpy() + dense_shape_tensor = _infer_value(inputs[2], params, mod).asnumpy() + + data = inputs[3] + + rows = [x[0] for x in indices_tensor] + cols = [x[1] for x in indices_tensor] + + # Create Numpy sparse Tensor(CSR) + weight_sp = sparse.csr_matrix((values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())) + weight_sp = sparse.csr_matrix(weight_sp.transpose()) + + weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype) + weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype) + weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype) + + ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs]) + + # If both are true means First input was dense and second was sparse + # TODO: Support other adjoint option too + if attr.get("adjoint_a") and attr.get("adjoint_b"): + ret = _op.transpose(ret) + + return ret + + return _impl def _identity(): def _impl(inputs, attr, params, mod): @@ -2407,6 +2441,7 @@ def _impl(inputs, attr, params, mod): "SpaceToBatchND": _space_to_batch_nd(), "SpaceToDepth": _space_to_depth(), "SparseToDense": _sparse_to_dense(), + "SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(), "Split": _split(False), "SplitV": _split(True), "Sqrt": AttrCvt("sqrt"), diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 0d012540343f..4810bdc35bbd 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -2046,7 +2046,7 @@ def sparse_transpose(x): Parameters ---------- - x : namedtuple. + x : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]]. The sparse weight matrix for the fast matrix transpose. Returns @@ -2055,7 +2055,9 @@ def sparse_transpose(x): Tuple of output sparse tensor (same shape and format as input), i.e. if CSR then output is in ([data, indices, indptr]) form """ - return expr.TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3) + if hasattr(x, "indices"): + return expr.TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3) + return expr.TupleWrapper(_make.sparse_transpose(x[0], x[1], x[2]), 3) def contrib_conv2d_winograd_without_weight_transform( diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index d125423968a9..aec6e824b958 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -180,7 +180,7 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): assert ( mb >= mi ), "Number of block rows in dense matrix must be larger than warp size: {} vs {}.".format( - warp_size, m + warp_size, mb ) mo = ceil_div(mb, mi) ni = 1 # TODO(tkonolige): how do I compute the number of warps per block? @@ -367,9 +367,14 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type): and isinstance(inputs[2], relay.Constant) and isinstance(inputs[3], relay.Constant) ): - sparse_matrix = sp.bsr_matrix( - (inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy()) - ) + if len(inputs[1].data.asnumpy().shape) == 1: + sparse_matrix = sp.csr_matrix( + (inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy()) + ).tobsr() + else : + sparse_matrix = sp.bsr_matrix( + (inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy()) + ) warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size) sparse_matrix = pad_sparse_matrix(sparse_matrix, warp_size) return relay.nn._make.sparse_dense_padded( diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 5ec4562543c2..7b3e4516fb3b 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1749,6 +1749,52 @@ def test_forward_batch_matmul(): _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), "int32", True, False) _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True) +####################################################################### +# SparseTensorDenseMatMul +# ---------------------------------- + + +def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=False): + """ One iteration of sparse_dense_matmul """ + + #TODO: Support adjoint options too + for adjoint_a in [False]: + for adjoint_b in [False]: + with tf.Graph().as_default(): + A_sp = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]], values=[4., 8.], dense_shape=A_shape) + B = tf.placeholder(shape=B_shape, dtype=dtype, name="B") + + if flip: + result = tf.sparse.sparse_dense_matmul(B, A_sp, adjoint_a=adjoint_a, adjoint_b=adjoint_b) + else: + result = tf.sparse.sparse_dense_matmul(A_sp, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b) + + B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) + + #TODO: There is an issue in cuda scheduling for csr, work in progress + compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True) + +def test_forward_sparse_dense_matmul(): + """ sparse_dense_matmul op test""" + ################################################################### + # + # In order to create a SparseTensor, it requires 3 input as below: + # SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) + # + # Above Sparse can be represented in Dense as below: + # [[1, 0, 0, 0] + # [0, 0, 2, 0] + # [0, 0, 0, 0]] + # + #------------------------------------------------------------------ + + #TODO: False case for flip need to be supported + #_test_sparse_dense_matmul([[0, 0], [1, 2]], [4., 8.], [3, 4], [4, 3], "float32") + _test_sparse_dense_matmul([[0, 0], [1, 2]], [4., 8.], [3, 5], [4, 3], "float32", True) + _test_sparse_dense_matmul([[0, 0], [1, 2]], [4., 8.], [3, 3], [3, 3], "float32", True) + _test_sparse_dense_matmul([[0, 0], [1, 3], [4, 3]], [3., 6., 9.], [5, 5], [5, 5], "float32", True) + _test_sparse_dense_matmul([[0, 0], [1, 3], [4, 3]], [3., 6., 9.], [9, 5], [7, 9], "float32", True) + ####################################################################### # StridedSlice From 26d089d56cf75dee4a96f2123bf911a1cb4a73a6 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Thu, 15 Oct 2020 00:29:13 +0530 Subject: [PATCH 2/4] Lint error resolved --- python/tvm/relay/frontend/tensorflow.py | 6 ++- python/tvm/topi/cuda/sparse.py | 10 ++--- .../frontend/tensorflow/test_forward.py | 40 ++++++++++++------- 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 38decfac7eff..44106b1b2a34 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -902,6 +902,7 @@ def _impl(inputs, attr, params, mod): return _impl + def _sparse_tensor_dense_matmul(): # Sparse utility from Numpy from scipy import sparse @@ -919,7 +920,9 @@ def _impl(inputs, attr, params, mod): cols = [x[1] for x in indices_tensor] # Create Numpy sparse Tensor(CSR) - weight_sp = sparse.csr_matrix((values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())) + weight_sp = sparse.csr_matrix( + (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist()) + ) weight_sp = sparse.csr_matrix(weight_sp.transpose()) weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype) @@ -937,6 +940,7 @@ def _impl(inputs, attr, params, mod): return _impl + def _identity(): def _impl(inputs, attr, params, mod): return inputs[0] diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index aec6e824b958..ebac5517d46c 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -369,12 +369,12 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type): ): if len(inputs[1].data.asnumpy().shape) == 1: sparse_matrix = sp.csr_matrix( - (inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy()) - ).tobsr() - else : + (inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy()) + ).tobsr() + else: sparse_matrix = sp.bsr_matrix( - (inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy()) - ) + (inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy()) + ) warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size) sparse_matrix = pad_sparse_matrix(sparse_matrix, warp_size) return relay.nn._make.sparse_dense_padded( diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 7b3e4516fb3b..7755214b9fa6 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1749,6 +1749,7 @@ def test_forward_batch_matmul(): _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), "int32", True, False) _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True) + ####################################################################### # SparseTensorDenseMatMul # ---------------------------------- @@ -1757,23 +1758,30 @@ def test_forward_batch_matmul(): def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=False): """ One iteration of sparse_dense_matmul """ - #TODO: Support adjoint options too + # TODO: Support adjoint options too for adjoint_a in [False]: for adjoint_b in [False]: with tf.Graph().as_default(): - A_sp = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]], values=[4., 8.], dense_shape=A_shape) + A_sp = tf.sparse.SparseTensor( + indices=[[0, 0], [1, 2]], values=[4.0, 8.0], dense_shape=A_shape + ) B = tf.placeholder(shape=B_shape, dtype=dtype, name="B") if flip: - result = tf.sparse.sparse_dense_matmul(B, A_sp, adjoint_a=adjoint_a, adjoint_b=adjoint_b) + result = tf.sparse.sparse_dense_matmul( + B, A_sp, adjoint_a=adjoint_a, adjoint_b=adjoint_b + ) else: - result = tf.sparse.sparse_dense_matmul(A_sp, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b) + result = tf.sparse.sparse_dense_matmul( + A_sp, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b + ) B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) - #TODO: There is an issue in cuda scheduling for csr, work in progress + # TODO: There is an issue in cuda scheduling for csr, work in progress compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True) + def test_forward_sparse_dense_matmul(): """ sparse_dense_matmul op test""" ################################################################### @@ -1781,19 +1789,23 @@ def test_forward_sparse_dense_matmul(): # In order to create a SparseTensor, it requires 3 input as below: # SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) # - # Above Sparse can be represented in Dense as below: + # Above Sparse can be represented in Dense as below : # [[1, 0, 0, 0] # [0, 0, 2, 0] # [0, 0, 0, 0]] # - #------------------------------------------------------------------ - - #TODO: False case for flip need to be supported - #_test_sparse_dense_matmul([[0, 0], [1, 2]], [4., 8.], [3, 4], [4, 3], "float32") - _test_sparse_dense_matmul([[0, 0], [1, 2]], [4., 8.], [3, 5], [4, 3], "float32", True) - _test_sparse_dense_matmul([[0, 0], [1, 2]], [4., 8.], [3, 3], [3, 3], "float32", True) - _test_sparse_dense_matmul([[0, 0], [1, 3], [4, 3]], [3., 6., 9.], [5, 5], [5, 5], "float32", True) - _test_sparse_dense_matmul([[0, 0], [1, 3], [4, 3]], [3., 6., 9.], [9, 5], [7, 9], "float32", True) + # ------------------------------------------------------------------ + + # TODO: False case for flip need to be supported + # _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 4], [4, 3], "float32") + _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 5], [4, 3], "float32", True) + _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 3], [3, 3], "float32", True) + _test_sparse_dense_matmul( + [[0, 0], [1, 3], [4, 3]], [3.0, 6.0, 9.0], [5, 5], [5, 5], "float32", True + ) + _test_sparse_dense_matmul( + [[0, 0], [1, 3], [4, 3]], [3.0, 6.0, 9.0], [9, 5], [7, 9], "float32", True + ) ####################################################################### From 110f577c1388f855a5fb21589e2306d670bafed0 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Sun, 18 Oct 2020 20:47:59 +0530 Subject: [PATCH 3/4] [1] Review comments handled --- python/tvm/relay/frontend/tensorflow.py | 12 +++++++----- tests/python/frontend/tensorflow/test_forward.py | 4 +--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 44106b1b2a34..ea59797b093b 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -904,8 +904,8 @@ def _impl(inputs, attr, params, mod): def _sparse_tensor_dense_matmul(): - # Sparse utility from Numpy - from scipy import sparse + # Sparse utility from scipy + from scipy.sparse import csr_matrix def _impl(inputs, attr, params, mod): assert len(inputs) == 4, "There should be 4 input tensors" @@ -919,11 +919,11 @@ def _impl(inputs, attr, params, mod): rows = [x[0] for x in indices_tensor] cols = [x[1] for x in indices_tensor] - # Create Numpy sparse Tensor(CSR) - weight_sp = sparse.csr_matrix( + # Create scipy sparse Tensor(CSR) + weight_sp = csr_matrix( (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist()) ) - weight_sp = sparse.csr_matrix(weight_sp.transpose()) + weight_sp = csr_matrix(weight_sp.transpose()) weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype) weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype) @@ -935,6 +935,8 @@ def _impl(inputs, attr, params, mod): # TODO: Support other adjoint option too if attr.get("adjoint_a") and attr.get("adjoint_b"): ret = _op.transpose(ret) + else: + raise tvm.error.OpAttributeUnImplemented("Adjoint option is not supported yet.") return ret diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 7755214b9fa6..1f3b8bc58052 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1762,9 +1762,7 @@ def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=Fal for adjoint_a in [False]: for adjoint_b in [False]: with tf.Graph().as_default(): - A_sp = tf.sparse.SparseTensor( - indices=[[0, 0], [1, 2]], values=[4.0, 8.0], dense_shape=A_shape - ) + A_sp = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=A_shape) B = tf.placeholder(shape=B_shape, dtype=dtype, name="B") if flip: From 2f63ba8e31b10c4765fac7455e7dc070933ed5e6 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Tue, 3 Nov 2020 01:22:15 +0530 Subject: [PATCH 4/4] [2] Review comments handled --- python/tvm/relay/frontend/tensorflow.py | 9 +++++++-- tests/python/frontend/tensorflow/test_forward.py | 6 +++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index ea59797b093b..39f8f33b98dd 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -932,11 +932,16 @@ def _impl(inputs, attr, params, mod): ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs]) # If both are true means First input was dense and second was sparse - # TODO: Support other adjoint option too + # TODO(ANSHUMAN87): Support other adjoint option too if attr.get("adjoint_a") and attr.get("adjoint_b"): ret = _op.transpose(ret) else: - raise tvm.error.OpAttributeUnImplemented("Adjoint option is not supported yet.") + raise tvm.error.OpAttributeUnImplemented( + "Only tf.sparse.sparse_dense_matmul() with adjoint_a=True and adjoint_b=True" + " is supported, but adjoint_a={} and adjoint_b={} was supplied.".format( + attr.get("adjoint_a"), attr.get("adjoint_b") + ) + ) return ret diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 1f3b8bc58052..5e8dcf38fa7a 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1758,7 +1758,7 @@ def test_forward_batch_matmul(): def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=False): """ One iteration of sparse_dense_matmul """ - # TODO: Support adjoint options too + # TODO(ANSHUMAN87): Support adjoint options too for adjoint_a in [False]: for adjoint_b in [False]: with tf.Graph().as_default(): @@ -1776,7 +1776,7 @@ def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=Fal B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) - # TODO: There is an issue in cuda scheduling for csr, work in progress + # TODO(ANSHUMAN87): There is an issue in cuda scheduling for csr, work in progress compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True) @@ -1794,7 +1794,7 @@ def test_forward_sparse_dense_matmul(): # # ------------------------------------------------------------------ - # TODO: False case for flip need to be supported + # TODO(ANSHUMAN87): False case for flip need to be supported # _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 4], [4, 3], "float32") _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 5], [4, 3], "float32", True) _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 3], [3, 3], "float32", True)