From 29f7b037da4d0a91aa22806c72ab9efa0a109f96 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Fri, 9 Jul 2021 11:51:08 -0700 Subject: [PATCH] [TOPI] Add support for arbitrary dtypes to CSRMV and CSRMM --- python/tvm/topi/nn/sparse.py | 12 +++---- python/tvm/topi/sparse/csrmm.py | 15 ++++++--- python/tvm/topi/sparse/csrmv.py | 15 ++++++--- tests/python/topi/python/test_topi_sparse.py | 35 ++++++++++---------- 4 files changed, 45 insertions(+), 32 deletions(-) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 73998db6f162..948847e60d92 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -31,7 +31,7 @@ def sparse_dense_sp_rhs(data, weight_data, weight_indices, weight_indptr): Parameters ---------- data : tvm.te.Tensor - 2-D with shape [M, K], float32 + 2-D with shape [M, K] weight_data : tvm.te.Tensor 1-D with shape [nnz] (CSR) or @@ -78,7 +78,7 @@ def sparse_dense_sp_lhs(data_data, data_indices, data_indptr, weight): 1-D with shape [(M + 1) // bs_r] (BSR) weight: - 2-D with shape [N, K], float32 + 2-D with shape [N, K] Returns ------- @@ -105,7 +105,7 @@ def sparse_dense(dense_data, sparse_data, sparse_indices, sparse_indptr, sparse_ Parameters ---------- dense_data : tvm.te.Tensor - 2-D with shape [M, K], float32 + 2-D with shape [M, K] sparse_data : tvm.te.Tensor 1-D with shape [nnz] (CSR) or @@ -239,7 +239,7 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr): Parameters ---------- sparse_data : tvm.te.Tensor - 1-D with shape [nonzeros], dtype of 'float32' + 1-D with shape [nonzeros] sparse_indices : tvm.te.Tensor 1-D with shape [nonzeros], dtype of 'int32' @@ -250,7 +250,7 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr): Returns ------- out_data : tvm.te.Tensor - 1-D with shape [nonzeros], dtype of 'float32' + 1-D with shape [nonzeros] out_indices : tvm.te.Tensor 1-D with shape [nonzeros], dtype of 'int32' @@ -275,7 +275,7 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr): ins[0], ins[1], ins[2], outs[0], outs[1], outs[2] ), tag="sparse_transpose_csr", - dtype=["float32", "int32", "int32"], + dtype=[sparse_data.dtype, "int32", "int32"], name="out", ) diff --git a/python/tvm/topi/sparse/csrmm.py b/python/tvm/topi/sparse/csrmm.py index 39ba3332fc72..4d659c801103 100644 --- a/python/tvm/topi/sparse/csrmm.py +++ b/python/tvm/topi/sparse/csrmm.py @@ -20,6 +20,7 @@ from tvm import te from .. import tag from ..utils import simplify +from ...tir.generic import cast def csrmm_default(data, indices, indptr, weight, bias=None): @@ -57,6 +58,12 @@ def csrmm_default(data, indices, indptr, weight, bias=None): assert isinstance( weight, te.tensor.Tensor ), "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight)) + assert ( + data.dtype == weight.dtype + ), "Data and weight must have the same dtype, but they have %s and %s" % ( + data.dtype, + weight.dtype, + ) if bias is not None: assert len(bias.shape) == 1 M = simplify(indptr.shape[0] - 1) @@ -74,9 +81,9 @@ def csrmm_default_ir(data, indices, indptr, weight, out): _, N = weight.shape with irb.for_range(0, N, kind="vectorize", name="n") as n: with irb.for_range(0, M, kind="parallel", name="row") as row: - dot = irb.allocate("float32", (1,), name="dot", scope="local") - out_ptr[row * N + n] = 0.0 - dot[0] = 0.0 + dot = irb.allocate(data.dtype, (1,), name="dot", scope="local") + out_ptr[row * N + n] = cast(0, data.dtype) + dot[0] = cast(0, data.dtype) row_start = indptr_ptr[row] row_end = indptr_ptr[row + 1] row_elems = row_end - row_start @@ -92,7 +99,7 @@ def csrmm_default_ir(data, indices, indptr, weight, out): [data, indices, indptr, weight], lambda ins, outs: csrmm_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]), tag="csrmm", - dtype="float32", + dtype=data.dtype, name="out", ) if bias is not None: diff --git a/python/tvm/topi/sparse/csrmv.py b/python/tvm/topi/sparse/csrmv.py index a2d22afe01e0..3c2016c6513a 100644 --- a/python/tvm/topi/sparse/csrmv.py +++ b/python/tvm/topi/sparse/csrmv.py @@ -19,6 +19,7 @@ import tvm from tvm import te from .. import tag +from ...tir.generic import cast def csrmv_default(data, indices, indptr, weight, bias=None): @@ -50,6 +51,12 @@ def csrmv_default(data, indices, indptr, weight, bias=None): assert isinstance( weight, te.tensor.Tensor ), "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight)) + assert ( + data.dtype == weight.dtype + ), "Data and weight must have the same dtype, but they have %s and %s" % ( + data.dtype, + weight.dtype, + ) if bias is not None: assert len(bias.shape) == 1 batch = indptr.shape[0] - 1 @@ -64,9 +71,9 @@ def csrmv_default_ir(data, indices, indptr, weight, out): out_ptr = irb.buffer_ptr(out) num_rows = indptr.shape[0] - 1 with irb.for_range(0, num_rows, kind="parallel", name="row") as row: - dot = irb.allocate("float32", (1,), name="dot", scope="local") - out_ptr[row] = 0.0 - dot[0] = 0.0 + dot = irb.allocate(data.dtype, (1,), name="dot", scope="local") + out_ptr[row] = cast(0, data.dtype) + dot[0] = cast(0, data.dtype) row_start = indptr_ptr[row] row_end = indptr_ptr[row + 1] row_elems = row_end - row_start @@ -82,7 +89,7 @@ def csrmv_default_ir(data, indices, indptr, weight, out): [data, indices, indptr, weight], lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]), tag="csrmv", - dtype="float32", + dtype=data.dtype, name="csrmv", ) if bias is not None: diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py index 003b89f7122a..11006576fea3 100644 --- a/tests/python/topi/python/test_topi_sparse.py +++ b/tests/python/topi/python/test_topi_sparse.py @@ -35,21 +35,20 @@ } -def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True): +def verify_dynamic_csrmv(batch, in_dim, out_dim, dtype, use_bias=True): nr, nc, n = te.var("nr"), te.var("nc"), te.var("n") - dtype = "float32" A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name="A") - B = te.placeholder((in_dim, 1), name="B") - C = te.placeholder((nr,), name="C") + B = te.placeholder((in_dim, 1), dtype=dtype, name="B") + C = te.placeholder((nr,), dtype=dtype, name="C") D = topi.sparse.csrmv(A, B, C if use_bias else None) s = te.create_schedule(D.op) dtype = A.dtype # get the test data def get_ref_data(): - a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype) - 0.5, 0.0) - b_np = np.random.uniform(size=(in_dim, 1)).astype(dtype) - 0.5 - c_np = np.random.uniform(size=(batch,)).astype(dtype) + a_np = np.random.uniform(size=(batch, in_dim), high=100).astype(dtype) + b_np = np.random.uniform(size=(in_dim, 1), high=100).astype(dtype) + c_np = np.random.uniform(size=(batch,), high=100).astype(dtype) if use_bias: d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1)) else: @@ -81,21 +80,20 @@ def check_device(device): check_device(device) -def verify_dynamic_csrmm(batch, in_dim, out_dim, use_bias=True): +def verify_dynamic_csrmm(batch, in_dim, out_dim, dtype, use_bias=True): nr, nc, n = te.var("nr"), te.var("nc"), te.var("n") - dtype = "float32" A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name="A") - B = te.placeholder((in_dim, out_dim), name="B") - C = te.placeholder((nr,), name="C") + B = te.placeholder((in_dim, out_dim), dtype=dtype, name="B") + C = te.placeholder((nr,), dtype=dtype, name="C") D = topi.sparse.csrmm(A, B, C if use_bias else None) s = te.create_schedule(D.op) dtype = A.dtype # get the test data def get_ref_data(): - a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype) - 0.5, 0.0) - b_np = np.random.uniform(size=(in_dim, out_dim)).astype(dtype) - 0.5 - c_np = np.random.uniform(size=(batch,)).astype(dtype) + a_np = np.random.uniform(size=(batch, in_dim), high=100).astype(dtype) + b_np = np.random.uniform(size=(in_dim, out_dim), high=100).astype(dtype) + c_np = np.random.uniform(size=(batch,), high=100).astype(dtype) if use_bias: d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1)) else: @@ -212,14 +210,15 @@ def check_device(device): def test_csrmv(): - verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=False) - verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=True) + verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, dtype="float32", use_bias=False) + verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, dtype="float64", use_bias=True) + verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, dtype="int32", use_bias=True) def test_csrmm(): M, K, N = 5, 7, 2 - verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=False) - verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=True) + verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, dtype="int64", use_bias=False) + verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, dtype="float64", use_bias=True) def test_dense_si():