Merged
125 changes: 71 additions & 54 deletions python/tvm/topi/nn/dense.py
@@ -19,7 +19,7 @@
import tvm
from tvm import auto_scheduler, te

from .. import tag
from .. import tag, add


def matmul(
@@ -65,86 +65,103 @@ def matmul(
output : tvm.te.Tensor
2-D with shape [batch, out_dim]
"""
# TODO(jcf94): Add multi-dim support for tensor_a
assert len(tensor_a.shape) == 2, "only support 2-dim matmul"
# TODO(yixin): support cases for 1-dim input
# TODO(yixin): adding support and further check for >2-dim input in autotvm template
assert (
len(tensor_a.shape) >= 2 and len(tensor_b.shape) >= 2
), "1-dim matmul is not supported yet."
if bias is not None:
assert len(bias.shape) == 1
if out_dtype is None:
out_dtype = tensor_a.dtype
if transpose_a:
in_dim, batch = tensor_a.shape
reduce_dim_a, in_dim = tensor_a.shape[-2:]
else:
batch, in_dim = tensor_a.shape
in_dim, reduce_dim_a = tensor_a.shape[-2:]
batch_dims_a = tensor_a.shape[:-2]

if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
assert len(tensor_b).shape == 2, "only support 2-dim matmul when using auto-scheduler"
[Review comment] @happyme531 (Dec 25, 2023):
This should be len(tensor_b.shape). (nobody probably tested this)

[Reply] (Member):
cc @Ubospica please followup

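A minimal, self-contained sketch of the issue flagged above (FakeTensor is a hypothetical stand-in, not TVM's te.Tensor): len(tensor_b).shape parses as (len(tensor_b)).shape, so even where len() is defined it yields an int, which has no .shape attribute, and the assert raises instead of checking the rank.

```python
# Illustration only: FakeTensor is a made-up stand-in, not TVM's te.Tensor.
class FakeTensor:
    def __init__(self, shape):
        self.shape = shape

    def __len__(self):
        return len(self.shape)


t = FakeTensor((4, 5))

# Intended check: the tensor is 2-dimensional.
assert len(t.shape) == 2

# Form from the diff: len(t) is an int, and ints have no .shape attribute.
try:
    assert len(t).shape == 2
except AttributeError as err:
    print(err)  # 'int' object has no attribute 'shape'
```
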
out_dim, reduce_dim_b = auto_scheduler.get_shape_from_rewritten_layout(
auto_scheduler_rewritten_layout, ["j", "k"]
)
auto_scheduler.remove_index_check(tensor_b)
elif meta_schedule_original_shape:
auto_scheduler.rewrite_tensor_shape(tensor_b, meta_schedule_original_shape)
if transpose_b:
out_dim, red_dim = tensor_b.shape
out_dim, reduce_dim_b = tensor_b.shape[-2:]
else:
red_dim, out_dim = tensor_b.shape
reduce_dim_b, out_dim = tensor_b.shape[-2:]
elif transpose_b:
out_dim, red_dim = tensor_b.shape
out_dim, reduce_dim_b = tensor_b.shape[-2:]
else:
red_dim, out_dim = tensor_b.shape

# cmp should be done by values
condition = True
if isinstance(in_dim, tvm.tir.SizeVar): # "any_dim"
condition = False
elif isinstance(red_dim, tvm.tir.SizeVar): # "any_dim"
condition = False
if condition:
assert int(in_dim) == int(
red_dim
), "Inner dimensions of dense do not match. {in_dim} vs {red_dim}."

k = te.reduce_axis((0, in_dim), name="k")
if (transpose_a, transpose_b) == (True, True):
compute_lambda = lambda i, j: te.sum(
tensor_a[k, i].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k
reduce_dim_b, out_dim = tensor_b.shape[-2:]
batch_dims_b = tensor_b.shape[:-2]

if not isinstance(reduce_dim_a, tvm.tir.Var) and not isinstance(reduce_dim_b, tvm.tir.Var):
assert int(reduce_dim_a) == int(
reduce_dim_b
), f"Reduction dimensions of dense do not match. {reduce_dim_a} vs {reduce_dim_b}."

result_ndim = max(len(batch_dims_a), len(batch_dims_b))
batch_dims_a = [1] * (result_ndim - len(batch_dims_a)) + batch_dims_a
batch_dims_b = [1] * (result_ndim - len(batch_dims_b)) + batch_dims_b

for idx, (l, r) in enumerate(zip(batch_dims_a, batch_dims_b)):
if (
not isinstance(l, tvm.tir.Var)
and not isinstance(r, tvm.tir.Var)
and int(l) != 1
and int(r) != 1
):
assert int(l) == int(r), (
"Batch dimensions of dense do not match: "
f"{tensor_a.shape[:-2]} vs {tensor_b.shape[:-2]}."
)
if not isinstance(l, tvm.tir.Var) and int(l) == 1:
batch_dims_a[idx] = batch_dims_b[idx]

k = te.reduce_axis((0, reduce_dim_a), name="k")

def compute(*indices):
batch_indices_a = indices[-len(tensor_a.shape) : -2]
batch_indices_a = [
i if isinstance(dim, tvm.tir.Var) or int(dim) != 1 else 0
for i, dim in zip(batch_indices_a, tensor_a.shape[:-2])
]
batch_indices_b = indices[-len(tensor_b.shape) : -2]
batch_indices_b = [
i if isinstance(dim, tvm.tir.Var) or int(dim) != 1 else 0
for i, dim in zip(batch_indices_b, tensor_b.shape[:-2])
]
i, j = indices[-2:]
a_indices = (*batch_indices_a, k, i) if transpose_a else (*batch_indices_a, i, k)
b_indices = (*batch_indices_b, j, k) if transpose_b else (*batch_indices_b, k, j)
return te.sum(
tensor_a[a_indices].astype(out_dtype) * tensor_b[b_indices].astype(out_dtype), axis=k
)
compute_name = "T_matmul_TT"
compute_tag = "matmul"
elif (transpose_a, transpose_b) == (True, False):
compute_lambda = lambda i, j: te.sum(
tensor_a[k, i].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k
)
compute_name = "T_matmul_TN"
compute_tag = "matmul"
elif (transpose_a, transpose_b) == (False, True):
compute_lambda = lambda i, j: te.sum(
tensor_a[i, k].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k
)
compute_name = "T_matmul_NT"
# TODO(jcf94): Remove `dense` when `matmul` is finally ready
compute_tag = "dense"
else: # (transpose_a, transpose_b) == (False, False):
compute_lambda = lambda i, j: te.sum(
tensor_a[i, k].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k
)
compute_name = "T_matmul_NN"
compute_tag = "matmul"

compute_name = {
(True, True): "T_matmul_TT",
(True, False): "T_matmul_TN",
(False, True): "T_matmul_NT",
(False, False): "T_matmul_NN",
}[(transpose_a, transpose_b)]

# TODO(jcf94): Remove `dense` when `matmul` is finally ready
compute_tag = "dense" if (transpose_a, transpose_b) == (False, True) else "matmul"

mat = te.compute(
(batch, out_dim),
compute_lambda,
(*batch_dims_a, in_dim, out_dim),
compute,
name=compute_name,
tag=compute_tag,
attrs={"layout_free_placeholders": [tensor_b]},
)

if bias is not None:
mat = te.compute(
(batch, out_dim),
lambda i, j: mat[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST,
)
mat = add(mat, bias.astype(out_dtype))

if auto_scheduler_rewritten_layout:
mat = auto_scheduler.rewrite_compute_body(mat, auto_scheduler_rewritten_layout)
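For context, a rough NumPy sketch (illustration only, not the TVM implementation) of the semantics the rewritten compute targets: leading batch dimensions broadcast as in numpy.matmul, transpose_a/transpose_b swap the last two axes of the corresponding operand, and an optional 1-D bias is added over the output dimension.

```python
# NumPy reference for one of the new broadcast cases, e.g. shapes (1, 5, 3) x (4, 2, 3)
# with transpose_b=True and a bias of length 2 (illustration only).
import numpy as np

a = np.random.rand(1, 5, 3).astype("float32")   # batch dim 1 broadcasts against 4
b = np.random.rand(4, 2, 3).astype("float32")   # transpose_b layout: [..., out_dim, reduce_dim]
bias = np.random.rand(2).astype("float32")

out = np.matmul(a, np.swapaxes(b, -1, -2)) + bias
print(out.shape)  # (4, 5, 2)
```
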
61 changes: 53 additions & 8 deletions tests/python/topi/python/test_topi_matmul.py
@@ -41,15 +41,40 @@ def with_tvm(lam, *args):
return out_nd.numpy()


def verify_nn_matmul(sa, sb, transp_a, transp_b):
def verify_nn_matmul(sa, sb, transp_a, transp_b, bias=False):
a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)
c1 = np.matmul(np.transpose(a) if transp_a else a, np.transpose(b) if transp_b else b)
c2 = with_tvm(
lambda A, B: topi.nn.matmul(A, B, transpose_a=transp_a, transpose_b=transp_b),
a,
b,
)
if bias:
bias_shape = sb[-2] if transp_b else sb[-1]
bias_np = np.random.uniform(low=-1.0, high=1.0, size=(bias_shape,)).astype(np.float32)

a_np = a
if transp_a:
axes = list(range(len(sa)))
axes[-2], axes[-1] = axes[-1], axes[-2]
a_np = np.transpose(a_np, axes)
b_np = b
if transp_b:
axes = list(range(len(sb)))
axes[-2], axes[-1] = axes[-1], axes[-2]
b_np = np.transpose(b_np, axes)

if bias:
c1 = np.matmul(a_np, b_np) + bias_np
c2 = with_tvm(
lambda A, B, bias: topi.nn.matmul(
A, B, transpose_a=transp_a, transpose_b=transp_b, bias=bias
),
a,
b,
bias_np,
)
else:
c1 = np.matmul(a_np, b_np)
c2 = with_tvm(
lambda A, B: topi.nn.matmul(A, B, transpose_a=transp_a, transpose_b=transp_b), a, b
)

tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)


@@ -60,10 +85,30 @@ def test_nn_matmul():
verify_nn_matmul((2, 2), (2, 2), True, True)
verify_nn_matmul((2, 3), (3, 5), False, False)
verify_nn_matmul((5, 3), (3, 2), False, False)
verify_nn_matmul((3, 5), (3, 2), True, False)
verify_nn_matmul((3, 5), (2, 3), True, True)
verify_nn_matmul((3, 5), (3, 2), True, False)
verify_nn_matmul((5, 3), (2, 3), False, True)
# matmul with bias
verify_nn_matmul((5, 3), (3, 2), False, False, True)
verify_nn_matmul((3, 5), (2, 3), True, True, True)
verify_nn_matmul((3, 5), (3, 2), True, False, True)
verify_nn_matmul((5, 3), (2, 3), False, True, True)
# batched matmul
verify_nn_matmul((4, 5, 3), (4, 3, 2), False, False)
verify_nn_matmul((4, 3, 5), (4, 2, 3), True, True)
verify_nn_matmul((4, 3, 5), (4, 3, 2), True, False)
verify_nn_matmul((4, 5, 3), (4, 2, 3), False, True)
# batched matmul with broadcast
verify_nn_matmul((4, 5, 3), (1, 2, 3), False, True)
verify_nn_matmul((1, 5, 3), (4, 2, 3), False, True)
verify_nn_matmul((5, 3), (4, 2, 3), False, True)
verify_nn_matmul((4, 5, 3), (2, 3), False, True)
verify_nn_matmul((2, 4, 5, 3), (1, 2, 3), False, True)
# batched matmul with bias
verify_nn_matmul((4, 5, 3), (4, 3, 2), False, False, True)
verify_nn_matmul((4, 3, 5), (4, 2, 3), True, True, True)
verify_nn_matmul((4, 3, 5), (4, 3, 2), True, False, True)
verify_nn_matmul((4, 5, 3), (4, 2, 3), False, True, True)


def verify_matmul(sa, sb, transp_a, transp_b):