12 changes: 6 additions & 6 deletions python/tvm/topi/nn/sparse.py
@@ -31,7 +31,7 @@ def sparse_dense_sp_rhs(data, weight_data, weight_indices, weight_indptr):
     Parameters
     ----------
     data : tvm.te.Tensor
-        2-D with shape [M, K], float32
+        2-D with shape [M, K]

     weight_data : tvm.te.Tensor
         1-D with shape [nnz] (CSR) or
@@ -78,7 +78,7 @@ def sparse_dense_sp_lhs(data_data, data_indices, data_indptr, weight):
         1-D with shape [(M + 1) // bs_r] (BSR)

     weight:
-        2-D with shape [N, K], float32
+        2-D with shape [N, K]

     Returns
     -------
@@ -105,7 +105,7 @@ def sparse_dense(dense_data, sparse_data, sparse_indices, sparse_indptr, sparse_
     Parameters
     ----------
     dense_data : tvm.te.Tensor
-        2-D with shape [M, K], float32
+        2-D with shape [M, K]

     sparse_data : tvm.te.Tensor
         1-D with shape [nnz] (CSR) or
@@ -239,7 +239,7 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr):
     Parameters
     ----------
     sparse_data : tvm.te.Tensor
-        1-D with shape [nonzeros], dtype of 'float32'
+        1-D with shape [nonzeros]

     sparse_indices : tvm.te.Tensor
         1-D with shape [nonzeros], dtype of 'int32'
@@ -250,7 +250,7 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr):
     Returns
     -------
     out_data : tvm.te.Tensor
-        1-D with shape [nonzeros], dtype of 'float32'
+        1-D with shape [nonzeros]

     out_indices : tvm.te.Tensor
         1-D with shape [nonzeros], dtype of 'int32'
@@ -275,7 +275,7 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr):
             ins[0], ins[1], ins[2], outs[0], outs[1], outs[2]
         ),
         tag="sparse_transpose_csr",
-        dtype=["float32", "int32", "int32"],
+        dtype=[sparse_data.dtype, "int32", "int32"],
         name="out",
     )

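Taken together, these hunks stop pinning float32 in both the docstrings and the te.extern output dtype list, so sparse_transpose now produces values in whatever dtype its input carries. A minimal sketch of the effect (illustrative shapes and nnz; assumes a local TVM build with topi importable):

```python
# Sketch only: sparse_transpose's output dtype now follows its input.
from tvm import te, topi

n, nnz = 4, 6  # illustrative dimensions
data = te.placeholder((nnz,), dtype="float64", name="sparse_data")
indices = te.placeholder((nnz,), dtype="int32", name="sparse_indices")
indptr = te.placeholder((n + 1,), dtype="int32", name="sparse_indptr")

out_data, out_indices, out_indptr = topi.nn.sparse_transpose(data, indices, indptr)
assert out_data.dtype == "float64"  # was pinned to "float32" before this change
```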
15 changes: 11 additions & 4 deletions python/tvm/topi/sparse/csrmm.py
@@ -20,6 +20,7 @@
 from tvm import te
 from .. import tag
 from ..utils import simplify
+from ...tir.generic import cast


 def csrmm_default(data, indices, indptr, weight, bias=None):
@@ -57,6 +58,12 @@ def csrmm_default(data, indices, indptr, weight, bias=None):
     assert isinstance(
         weight, te.tensor.Tensor
     ), "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight))
+    assert (
+        data.dtype == weight.dtype
+    ), "Data and weight must have the same dtype, but they have %s and %s" % (
+        data.dtype,
+        weight.dtype,
+    )
     if bias is not None:
         assert len(bias.shape) == 1
     M = simplify(indptr.shape[0] - 1)
@@ -74,9 +81,9 @@ def csrmm_default_ir(data, indices, indptr, weight, out):
     _, N = weight.shape
     with irb.for_range(0, N, kind="vectorize", name="n") as n:
         with irb.for_range(0, M, kind="parallel", name="row") as row:
-            dot = irb.allocate("float32", (1,), name="dot", scope="local")
-            out_ptr[row * N + n] = 0.0
-            dot[0] = 0.0
+            dot = irb.allocate(data.dtype, (1,), name="dot", scope="local")
+            out_ptr[row * N + n] = cast(0, data.dtype)
+            dot[0] = cast(0, data.dtype)
             row_start = indptr_ptr[row]
             row_end = indptr_ptr[row + 1]
             row_elems = row_end - row_start
@@ -92,7 +99,7 @@ def csrmm_default_ir(data, indices, indptr, weight, out):
         [data, indices, indptr, weight],
         lambda ins, outs: csrmm_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
         tag="csrmm",
-        dtype="float32",
+        dtype=data.dtype,
         name="out",
     )
     if bias is not None:
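Why the bare 0.0 literals had to go: TVM infers float32 for a plain Python float, so initializing a float64 or integer output buffer with 0.0 is a dtype mismatch at IR construction time. A hedged illustration of the constants involved (the cast helper is the one imported above):

```python
# Illustrative only: what the old and new zero initializers look like in TIR.
import tvm
from tvm.tir.generic import cast

zero_f32 = tvm.tir.const(0.0, "float32")  # what the old literal 0.0 amounted to
zero_f64 = cast(0, "float64")             # what cast(0, data.dtype) yields for float64 data
print(zero_f32.dtype, zero_f64.dtype)     # float32 float64
```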
15 changes: 11 additions & 4 deletions python/tvm/topi/sparse/csrmv.py
@@ -19,6 +19,7 @@
 import tvm
 from tvm import te
 from .. import tag
+from ...tir.generic import cast


 def csrmv_default(data, indices, indptr, weight, bias=None):
@@ -50,6 +51,12 @@ def csrmv_default(data, indices, indptr, weight, bias=None):
     assert isinstance(
         weight, te.tensor.Tensor
     ), "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight))
+    assert (
+        data.dtype == weight.dtype
+    ), "Data and weight must have the same dtype, but they have %s and %s" % (
+        data.dtype,
+        weight.dtype,
+    )
     if bias is not None:
         assert len(bias.shape) == 1
     batch = indptr.shape[0] - 1
@@ -64,9 +71,9 @@ def csrmv_default_ir(data, indices, indptr, weight, out):
     out_ptr = irb.buffer_ptr(out)
     num_rows = indptr.shape[0] - 1
     with irb.for_range(0, num_rows, kind="parallel", name="row") as row:
-        dot = irb.allocate("float32", (1,), name="dot", scope="local")
-        out_ptr[row] = 0.0
-        dot[0] = 0.0
+        dot = irb.allocate(data.dtype, (1,), name="dot", scope="local")
+        out_ptr[row] = cast(0, data.dtype)
+        dot[0] = cast(0, data.dtype)
         row_start = indptr_ptr[row]
         row_end = indptr_ptr[row + 1]
         row_elems = row_end - row_start
@@ -82,7 +89,7 @@ def csrmv_default_ir(data, indices, indptr, weight, out):
         [data, indices, indptr, weight],
         lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
         tag="csrmv",
-        dtype="float32",
+        dtype=data.dtype,
         name="csrmv",
     )
     if bias is not None:
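csrmv gets the same treatment as csrmm: the local accumulator, the zero initializers, and the te.extern output dtype all track data.dtype, and mismatched inputs now fail fast on the new assert. A usage sketch mirroring the updated tests (illustrative sizes; tvm.contrib.sparse supplies the CSR placeholder, as in the test file):

```python
# Sketch: csrmv with float64 operands; the output dtype follows the inputs.
from tvm import te, topi
import tvm.contrib.sparse as tvmsp

nr, nc, n = te.var("nr"), te.var("nc"), te.var("n")
A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype="float64", name="A")
B = te.placeholder((7, 1), dtype="float64", name="B")

D = topi.sparse.csrmv(A, B, None)
assert D.dtype == "float64"  # previously hard-coded to "float32"

# Mixing dtypes now trips the new assert instead of miscompiling:
B32 = te.placeholder((7, 1), dtype="float32", name="B32")
try:
    topi.sparse.csrmv(A, B32, None)
except AssertionError as err:
    print(err)  # "Data and weight must have the same dtype, ..."
```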
35 changes: 17 additions & 18 deletions tests/python/topi/python/test_topi_sparse.py
@@ -35,21 +35,20 @@
 }


-def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True):
+def verify_dynamic_csrmv(batch, in_dim, out_dim, dtype, use_bias=True):
     nr, nc, n = te.var("nr"), te.var("nc"), te.var("n")
-    dtype = "float32"
     A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name="A")
-    B = te.placeholder((in_dim, 1), name="B")
-    C = te.placeholder((nr,), name="C")
+    B = te.placeholder((in_dim, 1), dtype=dtype, name="B")
+    C = te.placeholder((nr,), dtype=dtype, name="C")
     D = topi.sparse.csrmv(A, B, C if use_bias else None)
     s = te.create_schedule(D.op)
-    dtype = A.dtype

     # get the test data
     def get_ref_data():
-        a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype) - 0.5, 0.0)
-        b_np = np.random.uniform(size=(in_dim, 1)).astype(dtype) - 0.5
-        c_np = np.random.uniform(size=(batch,)).astype(dtype)
+        a_np = np.random.uniform(size=(batch, in_dim), high=100).astype(dtype)
+        b_np = np.random.uniform(size=(in_dim, 1), high=100).astype(dtype)
+        c_np = np.random.uniform(size=(batch,), high=100).astype(dtype)
         if use_bias:
             d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1))
         else:
@@ -81,21 +80,20 @@ def check_device(device):
         check_device(device)


-def verify_dynamic_csrmm(batch, in_dim, out_dim, use_bias=True):
+def verify_dynamic_csrmm(batch, in_dim, out_dim, dtype, use_bias=True):
     nr, nc, n = te.var("nr"), te.var("nc"), te.var("n")
-    dtype = "float32"
     A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name="A")
-    B = te.placeholder((in_dim, out_dim), name="B")
-    C = te.placeholder((nr,), name="C")
+    B = te.placeholder((in_dim, out_dim), dtype=dtype, name="B")
+    C = te.placeholder((nr,), dtype=dtype, name="C")
     D = topi.sparse.csrmm(A, B, C if use_bias else None)
     s = te.create_schedule(D.op)
-    dtype = A.dtype

     # get the test data
     def get_ref_data():
-        a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype) - 0.5, 0.0)
-        b_np = np.random.uniform(size=(in_dim, out_dim)).astype(dtype) - 0.5
-        c_np = np.random.uniform(size=(batch,)).astype(dtype)
+        a_np = np.random.uniform(size=(batch, in_dim), high=100).astype(dtype)
+        b_np = np.random.uniform(size=(in_dim, out_dim), high=100).astype(dtype)
+        c_np = np.random.uniform(size=(batch,), high=100).astype(dtype)
         if use_bias:
             d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1))
         else:
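One detail worth spelling out in the get_ref_data changes above: switching to np.random.uniform(..., high=100) is what makes the integer dtype cases meaningful, because uniform's default [0, 1) range truncates to all zeros under an integer astype. A quick numpy illustration:

```python
# Why high=100 matters for the int32/int64 test cases.
import numpy as np

vals = np.random.uniform(size=(3,))  # defaults to the range [0, 1)
print(vals.astype("int32"))          # [0 0 0] -- vacuous integer test data
print(np.random.uniform(size=(3,), high=100).astype("int32"))  # meaningful ints
```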
@@ -212,14 +210,15 @@ def check_device(device):


 def test_csrmv():
-    verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=False)
-    verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=True)
+    verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, dtype="float32", use_bias=False)
[Inline review thread on the line above]

Contributor: As these are two different ops, I would suggest keeping the dtype inputs uniform across both cases.

Contributor (author): This tests both dtype and use_bias at the same time. I'm trying to avoid adding too much testing overhead. If necessary, I can split it.

Contributor: I meant adding the same datatype test cases for both. Feel free to ignore, it is minor. 👍
+    verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, dtype="float64", use_bias=True)
+    verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, dtype="int32", use_bias=True)


 def test_csrmm():
     M, K, N = 5, 7, 2
-    verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=False)
-    verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=True)
+    verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, dtype="int64", use_bias=False)
+    verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, dtype="float64", use_bias=True)


 def test_dense_si():