Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion python/tvm/topi/cuda/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,9 @@ def schedule_batch_matmul_cublas(_, outs):


@autotvm.register_topi_compute("batch_matmul_int8.cuda")
def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None):
def batch_matmul_int8(
cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
):
"""Batch Matmul operator for int8 on CUDA.

Parameters
Expand All @@ -258,11 +260,20 @@ def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None):
out_dtype : Optional[str]
Specifies the output data type for mixed precision batch matmul.

transpose_a : Optional[bool] = False
Whether the first tensor is in transposed format.

transpose_b : Optional[bool] = True
Whether the second tensor is in transposed format.

Returns
-------
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
del out_shape
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this because this compute doesn't process this? What happen if we just leave it as an unused argument? Or we could check if out_shape is consistent with the inferred output shape and throw a warning if not.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's just that some linting tools (like pylint) may throw an error on unused parameters. Since we've not enabled that check here, leaving this as an unused argument would not affect anything.

# TODO(jcf94): Deal with different transpose combinations
assert not transpose_a and transpose_b
if out_dtype is None:
out_dtype = x.dtype

Expand Down
9 changes: 7 additions & 2 deletions python/tvm/topi/cuda/batch_matmul_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,14 @@


@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
def batch_matmul_tensorcore(cfg, x, y, out_shape=None, out_dtype=None):
def batch_matmul_tensorcore(
cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
):
"""batch matmul tensorcore operator on cuda"""
# todo: deal with out_shape for broadcast, liuxin.ai
# TODO(jcf94): Deal with different transpose combinations
assert not transpose_a and transpose_b
# TODO(liuxin.ai): Deal with out_shape for broadcast
del out_shape
return batch_matmul_tensorcore_cuda(x, y, out_dtype)


Expand Down
7 changes: 5 additions & 2 deletions python/tvm/topi/rocm/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@


@autotvm.register_topi_compute("batch_matmul_rocblas.rocm")
def batch_matmul_rocblas(cfg, x, y, out_shape=None):
def batch_matmul_rocblas(
cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
):
"""Computes matrix multiplication of `x` and `y` via rocblas when
`x` and `y` are batched matrices.

Expand All @@ -40,12 +42,13 @@ def batch_matmul_rocblas(cfg, x, y, out_shape=None):
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
del out_dtype
batch, M, K = get_const_tuple(x.shape)
_, N, _ = get_const_tuple(y.shape)
if out_shape is not None:
assert out_shape[0] == batch, "Input and output batch sizes must match"
assert out_shape[1] == M and out_shape[2] == N, "Invalid output shape"
result = rocblas.batch_matmul(x, y, False, True)
result = rocblas.batch_matmul(x, y, transpose_a, transpose_b)
cfg.add_flop(batch * M * N * K * 2)
return result

Expand Down