From f5e5c9554ffccc8e7adbc39de2401bce324d91ba Mon Sep 17 00:00:00 2001
From: Ubospica
Date: Fri, 3 Nov 2023 00:37:26 +0000
Subject: [PATCH 1/2] 1102

---
 python/tvm/topi/nn/dense.py                  | 124 +++++++++++--------
 tests/python/topi/python/test_topi_matmul.py |  61 +++++++--
 2 files changed, 123 insertions(+), 62 deletions(-)

diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index ce3aebadb692..49c72225eb02 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -19,7 +19,7 @@
 import tvm
 from tvm import auto_scheduler, te
 
-from .. import tag
+from .. import tag, add
 
 
 def matmul(
@@ -65,86 +65,102 @@ def matmul(
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim]
     """
-    # TODO(jcf94): Add multi-dim support for tensor_a
-    assert len(tensor_a.shape) == 2, "only support 2-dim matmul"
+    # TODO(yixin): support cases for 1-dim input
+    assert (
+        len(tensor_a.shape) >= 2 and len(tensor_b.shape) >= 2
+    ), "1-dim matmul is not supported yet."
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
         out_dtype = tensor_a.dtype
     if transpose_a:
-        in_dim, batch = tensor_a.shape
+        red_dim_a, in_dim = tensor_a.shape[-2:]
     else:
-        batch, in_dim = tensor_a.shape
+        in_dim, red_dim_a = tensor_a.shape[-2:]
+    batch_dims_a = tensor_a.shape[:-2]
 
     if auto_scheduler_rewritten_layout:
         # Infer shape for the rewritten layout
-        out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
+        assert len(tensor_b.shape) == 2, "only support 2-dim matmul when using auto-scheduler"
+        out_dim, red_dim_b = auto_scheduler.get_shape_from_rewritten_layout(
             auto_scheduler_rewritten_layout, ["j", "k"]
         )
         auto_scheduler.remove_index_check(tensor_b)
     elif meta_schedule_original_shape:
         auto_scheduler.rewrite_tensor_shape(tensor_b, meta_schedule_original_shape)
         if transpose_b:
-            out_dim, red_dim = tensor_b.shape
+            out_dim, red_dim_b = tensor_b.shape[-2:]
         else:
-            red_dim, out_dim = tensor_b.shape
+            red_dim_b, out_dim = tensor_b.shape[-2:]
     elif transpose_b:
-        out_dim, red_dim = tensor_b.shape
+        out_dim, red_dim_b = tensor_b.shape[-2:]
     else:
-        red_dim, out_dim = tensor_b.shape
-
-    # cmp should be done by values
-    condition = True
-    if isinstance(in_dim, tvm.tir.SizeVar):  # "any_dim"
-        condition = False
-    elif isinstance(red_dim, tvm.tir.SizeVar):  # "any_dim"
-        condition = False
-    if condition:
-        assert int(in_dim) == int(
-            red_dim
-        ), "Inner dimensions of dense do not match. {in_dim} vs {red_dim}."
-
-    k = te.reduce_axis((0, in_dim), name="k")
-    if (transpose_a, transpose_b) == (True, True):
-        compute_lambda = lambda i, j: te.sum(
-            tensor_a[k, i].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k
+        red_dim_b, out_dim = tensor_b.shape[-2:]
+    batch_dims_b = tensor_b.shape[:-2]
+
+    if not isinstance(red_dim_a, tvm.tir.Var) and not isinstance(red_dim_b, tvm.tir.Var):
+        assert int(red_dim_a) == int(
+            red_dim_b
+        ), f"Reduction dimensions of dense do not match. {red_dim_a} vs {red_dim_b}."
+
+    result_ndim = max(len(batch_dims_a), len(batch_dims_b))
+    batch_dims_a = [1] * (result_ndim - len(batch_dims_a)) + batch_dims_a
+    batch_dims_b = [1] * (result_ndim - len(batch_dims_b)) + batch_dims_b
+
+    for idx, (l, r) in enumerate(zip(batch_dims_a, batch_dims_b)):
+        if (
+            not isinstance(l, tvm.tir.Var)
+            and not isinstance(r, tvm.tir.Var)
+            and int(l) != 1
+            and int(r) != 1
+        ):
+            assert int(l) == int(r), (
+                "Batch dimensions of dense do not match: "
+                f"{tensor_a.shape[:-2]} vs {tensor_b.shape[:-2]}."
+            )
+        if not isinstance(l, tvm.tir.Var) and int(l) == 1:
+            batch_dims_a[idx] = batch_dims_b[idx]
+
+    k = te.reduce_axis((0, red_dim_a), name="k")
+
+    def compute(*indices):
+        batch_indices_a = indices[-len(tensor_a.shape) : -2]
+        batch_indices_a = [
+            i if isinstance(dim, tvm.tir.Var) or int(dim) != 1 else 0
+            for i, dim in zip(batch_indices_a, tensor_a.shape[:-2])
+        ]
+        batch_indices_b = indices[-len(tensor_b.shape) : -2]
+        batch_indices_b = [
+            i if isinstance(dim, tvm.tir.Var) or int(dim) != 1 else 0
+            for i, dim in zip(batch_indices_b, tensor_b.shape[:-2])
+        ]
+        i, j = indices[-2:]
+        a_indices = (*batch_indices_a, k, i) if transpose_a else (*batch_indices_a, i, k)
+        b_indices = (*batch_indices_b, j, k) if transpose_b else (*batch_indices_b, k, j)
+        return te.sum(
+            tensor_a[a_indices].astype(out_dtype) * tensor_b[b_indices].astype(out_dtype), axis=k
         )
-        compute_name = "T_matmul_TT"
-        compute_tag = "matmul"
-    elif (transpose_a, transpose_b) == (True, False):
-        compute_lambda = lambda i, j: te.sum(
-            tensor_a[k, i].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k
-        )
-        compute_name = "T_matmul_TN"
-        compute_tag = "matmul"
-    elif (transpose_a, transpose_b) == (False, True):
-        compute_lambda = lambda i, j: te.sum(
-            tensor_a[i, k].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k
-        )
-        compute_name = "T_matmul_NT"
-        # TODO(jcf94): Remove `dense` when `matmul` is finally ready
-        compute_tag = "dense"
-    else:  # (transpose_a, transpose_b) == (False, False):
-        compute_lambda = lambda i, j: te.sum(
-            tensor_a[i, k].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k
-        )
-        compute_name = "T_matmul_NN"
-        compute_tag = "matmul"
+
+    compute_name = {
+        (True, True): "T_matmul_TT",
+        (True, False): "T_matmul_TN",
+        (False, True): "T_matmul_NT",
+        (False, False): "T_matmul_NN",
+    }[(transpose_a, transpose_b)]
+
+    # TODO(jcf94): Remove `dense` when `matmul` is finally ready
+    compute_tag = "dense" if (transpose_a, transpose_b) == (False, True) else "matmul"
     mat = te.compute(
-        (batch, out_dim),
-        compute_lambda,
+        (*batch_dims_a, in_dim, out_dim),
+        compute,
         name=compute_name,
         tag=compute_tag,
         attrs={"layout_free_placeholders": [tensor_b]},
     )
 
     if bias is not None:
-        mat = te.compute(
-            (batch, out_dim),
-            lambda i, j: mat[i, j] + bias[j].astype(out_dtype),
-            tag=tag.BROADCAST,
-        )
+        mat = add(mat, bias.astype(out_dtype))
 
     if auto_scheduler_rewritten_layout:
         mat = auto_scheduler.rewrite_compute_body(mat, auto_scheduler_rewritten_layout)
diff --git a/tests/python/topi/python/test_topi_matmul.py b/tests/python/topi/python/test_topi_matmul.py
index de2d4d3c4c8e..4b05dd3813e2 100644
--- a/tests/python/topi/python/test_topi_matmul.py
+++ b/tests/python/topi/python/test_topi_matmul.py
@@ -41,15 +41,40 @@ def with_tvm(lam, *args):
     return out_nd.numpy()
 
 
-def verify_nn_matmul(sa, sb, transp_a, transp_b):
+def verify_nn_matmul(sa, sb, transp_a, transp_b, bias=False):
     a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
     b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)
-    c1 = np.matmul(np.transpose(a) if transp_a else a, np.transpose(b) if transp_b else b)
-    c2 = with_tvm(
-        lambda A, B: topi.nn.matmul(A, B, transpose_a=transp_a, transpose_b=transp_b),
-        a,
-        b,
-    )
+    if bias:
+        bias_shape = sb[-2] if transp_b else sb[-1]
+        bias_np = np.random.uniform(low=-1.0, high=1.0, size=(bias_shape,)).astype(np.float32)
+
+    a_np = a
+    if transp_a:
+        axes = list(range(len(sa)))
+        axes[-2], axes[-1] = axes[-1], axes[-2]
+        a_np = np.transpose(a_np, axes)
+    b_np = b
+    if transp_b:
+        axes = list(range(len(sb)))
+        axes[-2], axes[-1] = axes[-1], axes[-2]
+        b_np = np.transpose(b_np, axes)
+
+    if bias:
+        c1 = np.matmul(a_np, b_np) + bias_np
+        c2 = with_tvm(
+            lambda A, B, bias: topi.nn.matmul(
+                A, B, transpose_a=transp_a, transpose_b=transp_b, bias=bias
+            ),
+            a,
+            b,
+            bias_np,
+        )
+    else:
+        c1 = np.matmul(a_np, b_np)
+        c2 = with_tvm(
+            lambda A, B: topi.nn.matmul(A, B, transpose_a=transp_a, transpose_b=transp_b), a, b
+        )
+
     tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)
 
 
@@ -60,10 +85,30 @@ def test_nn_matmul():
     verify_nn_matmul((2, 2), (2, 2), True, True)
     verify_nn_matmul((2, 3), (3, 5), False, False)
     verify_nn_matmul((5, 3), (3, 2), False, False)
-    verify_nn_matmul((3, 5), (3, 2), True, False)
     verify_nn_matmul((3, 5), (2, 3), True, True)
     verify_nn_matmul((3, 5), (3, 2), True, False)
     verify_nn_matmul((5, 3), (2, 3), False, True)
+    # matmul with bias
+    verify_nn_matmul((5, 3), (3, 2), False, False, True)
+    verify_nn_matmul((3, 5), (2, 3), True, True, True)
+    verify_nn_matmul((3, 5), (3, 2), True, False, True)
+    verify_nn_matmul((5, 3), (2, 3), False, True, True)
+    # batched matmul
+    verify_nn_matmul((4, 5, 3), (4, 3, 2), False, False)
+    verify_nn_matmul((4, 3, 5), (4, 2, 3), True, True)
+    verify_nn_matmul((4, 3, 5), (4, 3, 2), True, False)
+    verify_nn_matmul((4, 5, 3), (4, 2, 3), False, True)
+    # batched matmul with broadcast
+    verify_nn_matmul((4, 5, 3), (1, 2, 3), False, True)
+    verify_nn_matmul((1, 5, 3), (4, 2, 3), False, True)
+    verify_nn_matmul((5, 3), (4, 2, 3), False, True)
+    verify_nn_matmul((4, 5, 3), (2, 3), False, True)
+    verify_nn_matmul((2, 4, 5, 3), (1, 2, 3), False, True)
+    # batched matmul with bias
+    verify_nn_matmul((4, 5, 3), (4, 3, 2), False, False, True)
+    verify_nn_matmul((4, 3, 5), (4, 2, 3), True, True, True)
+    verify_nn_matmul((4, 3, 5), (4, 3, 2), True, False, True)
+    verify_nn_matmul((4, 5, 3), (4, 2, 3), False, True, True)
 
 
 def verify_matmul(sa, sb, transp_a, transp_b):

From c317d13265f09a1b541012e77d71be6d10c0e08b Mon Sep 17 00:00:00 2001
From: Ubospica
Date: Thu, 9 Nov 2023 09:17:01 +0000
Subject: [PATCH 2/2] 1109

---
 python/tvm/topi/nn/dense.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index 49c72225eb02..d81060fe8baa 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -66,6 +66,7 @@ def matmul(
         2-D with shape [batch, out_dim]
     """
     # TODO(yixin): support cases for 1-dim input
+    # TODO(yixin): adding support and further check for >2-dim input in autotvm template
     assert (
         len(tensor_a.shape) >= 2 and len(tensor_b.shape) >= 2
     ), "1-dim matmul is not supported yet."
@@ -74,34 +75,34 @@ def matmul(
     if out_dtype is None:
         out_dtype = tensor_a.dtype
     if transpose_a:
-        red_dim_a, in_dim = tensor_a.shape[-2:]
+        reduce_dim_a, in_dim = tensor_a.shape[-2:]
     else:
-        in_dim, red_dim_a = tensor_a.shape[-2:]
+        in_dim, reduce_dim_a = tensor_a.shape[-2:]
     batch_dims_a = tensor_a.shape[:-2]
 
     if auto_scheduler_rewritten_layout:
         # Infer shape for the rewritten layout
         assert len(tensor_b.shape) == 2, "only support 2-dim matmul when using auto-scheduler"
-        out_dim, red_dim_b = auto_scheduler.get_shape_from_rewritten_layout(
+        out_dim, reduce_dim_b = auto_scheduler.get_shape_from_rewritten_layout(
             auto_scheduler_rewritten_layout, ["j", "k"]
         )
         auto_scheduler.remove_index_check(tensor_b)
     elif meta_schedule_original_shape:
         auto_scheduler.rewrite_tensor_shape(tensor_b, meta_schedule_original_shape)
         if transpose_b:
-            out_dim, red_dim_b = tensor_b.shape[-2:]
+            out_dim, reduce_dim_b = tensor_b.shape[-2:]
         else:
-            red_dim_b, out_dim = tensor_b.shape[-2:]
+            reduce_dim_b, out_dim = tensor_b.shape[-2:]
     elif transpose_b:
-        out_dim, red_dim_b = tensor_b.shape[-2:]
+        out_dim, reduce_dim_b = tensor_b.shape[-2:]
     else:
-        red_dim_b, out_dim = tensor_b.shape[-2:]
+        reduce_dim_b, out_dim = tensor_b.shape[-2:]
     batch_dims_b = tensor_b.shape[:-2]
 
-    if not isinstance(red_dim_a, tvm.tir.Var) and not isinstance(red_dim_b, tvm.tir.Var):
-        assert int(red_dim_a) == int(
-            red_dim_b
-        ), f"Reduction dimensions of dense do not match. {red_dim_a} vs {red_dim_b}."
+    if not isinstance(reduce_dim_a, tvm.tir.Var) and not isinstance(reduce_dim_b, tvm.tir.Var):
+        assert int(reduce_dim_a) == int(
+            reduce_dim_b
+        ), f"Reduction dimensions of dense do not match. {reduce_dim_a} vs {reduce_dim_b}."
 
     result_ndim = max(len(batch_dims_a), len(batch_dims_b))
     batch_dims_a = [1] * (result_ndim - len(batch_dims_a)) + batch_dims_a
@@ -121,7 +122,7 @@ def matmul(
         if not isinstance(l, tvm.tir.Var) and int(l) == 1:
             batch_dims_a[idx] = batch_dims_b[idx]
 
-    k = te.reduce_axis((0, red_dim_a), name="k")
+    k = te.reduce_axis((0, reduce_dim_a), name="k")
 
     def compute(*indices):
         batch_indices_a = indices[-len(tensor_a.shape) : -2]
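
Usage note: a minimal sketch of how the batched, broadcast matmul exercised by the new tests can be driven through TOPI. It assumes the standard te.placeholder / te.create_schedule / tvm.build workflow on an LLVM target; the shapes mirror the verify_nn_matmul((4, 5, 3), (2, 3), False, True) case above, and all variable names are illustrative only.

    import numpy as np
    import tvm
    from tvm import te, topi

    # Batched LHS (4, 5, 3) multiplied with an un-batched, transposed RHS (2, 3);
    # the missing batch dimension of B is broadcast against A, giving a (4, 5, 2) result.
    A = te.placeholder((4, 5, 3), name="A", dtype="float32")
    B = te.placeholder((2, 3), name="B", dtype="float32")
    C = topi.nn.matmul(A, B, transpose_a=False, transpose_b=True)

    s = te.create_schedule(C.op)
    func = tvm.build(s, [A, B, C], target="llvm")

    a = tvm.nd.array(np.random.uniform(size=(4, 5, 3)).astype("float32"))
    b = tvm.nd.array(np.random.uniform(size=(2, 3)).astype("float32"))
    c = tvm.nd.array(np.zeros((4, 5, 2), dtype="float32"))
    func(a, b, c)
    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy().T, rtol=1e-5, atol=1e-5)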