From 13649ce00bfdc99b6e592e042746fa67d8f07442 Mon Sep 17 00:00:00 2001
From: wyc-ruiker
Date: Fri, 2 Jul 2021 17:02:14 +0800
Subject: [PATCH 01/11] add int8/int4 tensorcore for dense/batch_matmul

---
 python/tvm/relay/op/strategy/cuda.py          |  21 ++-
 .../tvm/topi/cuda/batch_matmul_tensorcore.py  |  75 +++++-----
 python/tvm/topi/cuda/dense_tensorcore.py      |  75 +++++-----
 python/tvm/topi/cuda/tensorcore_alter_op.py   | 140 ++++++++++--------
 python/tvm/topi/testing/__init__.py           |   1 +
 python/tvm/topi/testing/dense.py              |  53 +++++++
 .../relay/test_pass_legalize_tensorcore.py    |  68 ++++++---
 .../test_topi_batch_matmul_tensorcore.py      |  67 +++++++--
 .../topi/python/test_topi_dense_tensorcore.py | 100 ++++++++++---
 9 files changed, 405 insertions(+), 195 deletions(-)
 create mode 100644 python/tvm/topi/testing/dense.py

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index dd265e4b4d5b..5a80c48b3f15 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -832,13 +832,24 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         x, y = inputs
         _, M, K = get_const_tuple(x.shape)
         _, N, K = get_const_tuple(y.shape)
-        if x.dtype in ["float16", "int8", "uint8"] and (
-            (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
-            or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
-            or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
+        if (
+            (
+                x.dtype in ["float16", "int8", "uint8"]
+                and (
+                    (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
+                    or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
+                    or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
+                )
+            )
+            or (
+                x.dtype in ["int4", "uint4"]
+                and K % 32 == 0
+                and M % 8 == 0
+                and N % 8 == 0
+            )
         ):
             strategy.add_implementation(
-                wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore),
+                wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore, need_out_dtype=True),
                 wrap_topi_schedule(topi.cuda.schedule_batch_matmul_tensorcore),
                 name="batch_matmul_tensorcore.cuda",
                 plevel=20,
diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
index 962a8af7853b..67dd6d8c892e 100644
--- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py
+++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -29,10 +29,10 @@
 
 
 @autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
-def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
+def batch_matmul_tensorcore(cfg, x, y, out_shape=None, out_dtype=None):
     """batch matmul tensorcore operator on cuda"""
     # todo: deal with out_shape for broadcast, liuxin.ai
-    return batch_matmul_tensorcore_cuda(x, y)
+    return batch_matmul_tensorcore_cuda(x, y, out_dtype)
 
 
 @autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda")
@@ -57,10 +57,8 @@ def _schedule(cfg, s, C):
         A, B = s[C].op.input_tensors
         batch, m_dim, k_dim = get_const_tuple(A.shape)
         batch, n_dim, k_dim = get_const_tuple(B.shape)
+        data_dtype = A.dtype
         out_dtype = C.dtype
-        # inline astype fp16
-        s[A].compute_inline()
-        s[B].compute_inline()
 
         # Explicit memory access
         AS = s.cache_read(A, "shared", [C])
@@ -94,15 +92,26 @@ def _schedule(cfg, s, C):
         cfg.define_knob("vec", [1, 2, 4, 8])
 
         # Ensure that the default parameters are applicable when autotvm is not in use
-        if m_dim % 32 == 0 and n_dim % 8 == 0:
-            cfg.define_knob("wmma_m", [32, 16, 8])
-        elif m_dim % 16 == 0 and n_dim % 16 == 0:
-            cfg.define_knob("wmma_m", [16, 8, 32])
-        elif m_dim % 8 == 0 and n_dim % 32 == 0:
-            cfg.define_knob("wmma_m", [8, 16, 32])
+        if data_dtype in ["float16", "uint8", "int8"]:
+            if m_dim % 32 == 0
and n_dim % 8 == 0: + cfg.define_knob("wmma_m", [32, 16, 8]) + elif m_dim % 16 == 0 and n_dim % 16 == 0: + cfg.define_knob("wmma_m", [16, 8, 32]) + elif m_dim % 8 == 0 and n_dim % 32 == 0: + cfg.define_knob("wmma_m", [8, 16, 32]) + wmma_k = 16 + wmma_m = cfg["wmma_m"].val + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 + else: + wmma_m = wmma_n = 8 + wmma_k = 32 warp_size = 32 - wmma_k = 16 block_row_warps = cfg["block_row_warps"].val block_col_warps = cfg["block_col_warps"].val warp_row_tiles = cfg["warp_row_tiles"].val @@ -110,16 +119,8 @@ def _schedule(cfg, s, C): chunk = cfg["chunk"].val offset = cfg["offset"].val offsetCS = cfg["offsetCS"].val - wmma_m = cfg["wmma_m"].val vec = cfg["vec"].val - if wmma_m == 16: - wmma_n = 16 - elif wmma_m == 8: - wmma_n = 32 - elif wmma_m == 32: - wmma_n = 8 - # Define the stride of intrin functions AS_align = chunk * wmma_k + offset BS_align = chunk * wmma_k + offset @@ -211,10 +212,8 @@ def shared_shedule(stage, strides): shared_shedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) - # TODO: add checking here, datatype casting may cause precision loss - in_dtype = "float16" - AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype) - BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype) + AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype) + BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=data_dtype) k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm") CL_compute = te.compute( (wmma_m, wmma_n), @@ -236,7 +235,7 @@ def shared_shedule(stage, strides): "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), - "float16", + data_dtype, ), ) s[BF].tensorize( @@ -248,7 +247,7 @@ def shared_shedule(stage, strides): "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), - "float16", + data_dtype, ), ) s[CF].tensorize( @@ -270,7 +269,7 @@ def _callback(op): return s -def batch_matmul_tensorcore_cuda(x, y): +def batch_matmul_tensorcore_cuda(x, y, out_dtype=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. 
@@ -294,22 +293,26 @@ def batch_matmul_tensorcore_cuda(x, y): assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent" batch, M, K = x.shape N = y.shape[1] - out_dtype = x.dtype - assert ( - (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) - or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) - or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) - ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)" + if out_dtype is None: + out_dtype = x.dtype - x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, k].astype("float16")) - y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, k].astype("float16")) + assert x.dtype == y.dtype + assert x.dtype in ["float16", "uint8", "int8", "uint4", "int4"] + if x.dtype in ["float16", "uint8", "int8"]: + assert ( + (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) + or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) + or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) + ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)" + else: + assert(M % 8 == 0 and K % 32 == 0 and N % 8 == 0), "The shape of (M, K, N) must be multiple of (8, 32, 8)" k = te.reduce_axis((0, K), name="k") return te.compute( (batch, M, N), lambda b, i, j: te.sum( - x_16[b, i, k].astype(out_dtype) * y_16[b, j, k].astype(out_dtype), axis=k + x[b, i, k].astype(out_dtype) * y[b, j, k].astype(out_dtype), axis=k ), tag="batch_matmul_tensorcore", ) diff --git a/python/tvm/topi/cuda/dense_tensorcore.py b/python/tvm/topi/cuda/dense_tensorcore.py index 430f8044528c..d82c522eb3cb 100644 --- a/python/tvm/topi/cuda/dense_tensorcore.py +++ b/python/tvm/topi/cuda/dense_tensorcore.py @@ -60,21 +60,26 @@ def dense_tensorcore_cuda(data, weight, bias=None, out_dtype=None): out_dtype = data.dtype batch, in_dim = get_const_tuple(data.shape) out_dim, _ = get_const_tuple(weight.shape) - assert ( - (batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0) - or (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0) - or (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0) - ), ( - "The shape of (batch, in_dim, out_dim) " - "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now" - ) + + assert data.dtype == weight.dtype + assert data.dtype in ["float16", "int8", "uint8", "int4", "uint4"] + if data.dtype in ["float16", "int8", "uint8"]: + assert ( + (batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0) + or (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0) + or (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0) + ), ( + "The shape of (batch, in_dim, out_dim) " + "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now" + ) + else: + assert(batch % 8 == 0 and in_dim % 32 == 0 and out_dim % 8 == 0), "The shape of (batch, in_dim, out_dim) must be multiple of (8, 32, 8)" + k = te.reduce_axis((0, in_dim), name="k") - data_16 = te.compute((batch, in_dim), lambda b, i: data[b, i].astype("float16")) - weight_16 = te.compute((out_dim, in_dim), lambda o, i: weight[o, i].astype("float16")) matmul = te.compute( (batch, out_dim), lambda i, j: te.sum( - data_16[i, k].astype(out_dtype) * weight_16[j, k].astype(out_dtype), axis=k + data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k ), name="T_dense", tag="dense_tensorcore", @@ -92,9 +97,8 @@ def _schedule_dense_tensorcore(cfg, s, C): """Schedule dense operator using Tensorcore""" A, B = s[C].op.input_tensors batch, out_dim = get_const_tuple(C.shape) + data_dtype = A.dtype out_dtype = C.dtype - s[A].compute_inline() - 
s[B].compute_inline() # Explicit memory access AS = s.cache_read(A, "shared", [C]) @@ -127,16 +131,27 @@ def _schedule_dense_tensorcore(cfg, s, C): cfg.define_knob("offsetCS", [0, 8]) cfg.define_knob("vec", [1, 2, 4, 8]) - # Ensure that the default parameters are applicable when autotvm is not in use - if batch % 32 == 0 and out_dim % 8 == 0: - cfg.define_knob("wmma_m", [32, 16, 8]) - elif batch % 16 == 0 and out_dim % 16 == 0: - cfg.define_knob("wmma_m", [16, 8, 32]) - elif batch % 8 == 0 and out_dim % 32 == 0: - cfg.define_knob("wmma_m", [8, 16, 32]) + if data_dtype in ["float16", "int8", "uint8"]: + # Ensure that the default parameters are applicable when autotvm is not in use + if batch % 32 == 0 and out_dim % 8 == 0: + cfg.define_knob("wmma_m", [32, 16, 8]) + elif batch % 16 == 0 and out_dim % 16 == 0: + cfg.define_knob("wmma_m", [16, 8, 32]) + elif batch % 8 == 0 and out_dim % 32 == 0: + cfg.define_knob("wmma_m", [8, 16, 32]) + wmma_k = 16 + wmma_m = cfg["wmma_m"].val + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 + else: + wmma_m = wmma_n = 8 + wmma_k = 32 warp_size = 32 - wmma_k = 16 block_row_warps = cfg["block_row_warps"].val block_col_warps = cfg["block_col_warps"].val warp_row_tiles = cfg["warp_row_tiles"].val @@ -144,16 +159,8 @@ def _schedule_dense_tensorcore(cfg, s, C): chunk = cfg["chunk"].val offset = cfg["offset"].val offsetCS = cfg["offsetCS"].val - wmma_m = cfg["wmma_m"].val vec = cfg["vec"].val - if wmma_m == 16: - wmma_n = 16 - elif wmma_m == 8: - wmma_n = 32 - elif wmma_m == 32: - wmma_n = 8 - # Define the stride of intrin functions AS_align = chunk * wmma_k + offset BS_align = chunk * wmma_k + offset @@ -245,10 +252,8 @@ def shared_shedule(stage, strides): shared_shedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) - # TODO: add checking here, datatype casting may cause precision loss - in_dtype = "float16" - AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype) - BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype) + AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype) + BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=data_dtype) k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm") CL_compute = te.compute( (wmma_m, wmma_n), @@ -264,13 +269,13 @@ def shared_shedule(stage, strides): s[AF].tensorize( b_ii, intrin_wmma_load_matrix_A( - AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), "float16" + AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), data_dtype ), ) s[BF].tensorize( o_ii, intrin_wmma_load_matrix_W( - BF_stride, BS_stride, shape, "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), "float16" + BF_stride, BS_stride, shape, "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), data_dtype ), ) s[CF].tensorize( diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index eb7c71ddf1c9..f6a7757702ee 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -54,14 +54,14 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): # Collect the input exprs. x, y = inputs - # Pad input and output channels to use tensorcore schedule. 
- if dtype in ["float16"]: # todo: support int8/int4 - B, M, K = x_tensor.shape - B, N, K = y_tensor.shape - M = M.value - K = K.value - N = N.value + B, M, K = x_tensor.shape + B, N, K = y_tensor.shape + M = M.value + K = K.value + N = N.value + # Pad input and output channels to use tensorcore schedule. + if dtype in ["float16", "int8", "uint8"]: # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) if ( (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) @@ -70,31 +70,38 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): ): # no need to pad return None - candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] - (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N, candidates) - - if extra_flops > 2: - logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops) + elif dtype in ["int4", "uint4"]: + if (M % 8 == 0 and K % 32 == 0 and N % 8 == 0): + # no need to pad return None - logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops) - if dm or dk: - x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) - else: - x_ = x - if dn or dk: - y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) - else: - y_ = y - out_ = relay.nn.batch_matmul(x_, y_) - if dm or dn: - original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape) - else: - out = out_ - return out - return None + candidates = [(8, 32, 8)] + else: + return None + + (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N, candidates) + + if extra_flops > 2: + logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops) + return None + + logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops) + if dm or dk: + x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) + else: + x_ = x + if dn or dk: + y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) + else: + y_ = y + out_ = relay.nn.batch_matmul(x_, y_) + if dm or dn: + original_out_shape = [x.value for x in output_tensor.shape] + out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape) + else: + out = out_ + return out @nn.dense_legalize.register("cuda") @@ -125,18 +132,18 @@ def _dense_legalize(attrs, inputs, arg_types): # Collect the input exprs. x, y = inputs - # Pad input and output channels to use tensorcore schedule. - if dtype in ["float16"]: # todo: support int8/int4 - M, K = x_tensor.shape - N, K = y_tensor.shape - try: - M = M.value - K = K.value - N = N.value - except AttributeError: - # todo: deal with unfixed shape when compiling wdl model - return None + M, K = x_tensor.shape + N, K = y_tensor.shape + try: + M = M.value + K = K.value + N = N.value + except AttributeError: + # todo: deal with unfixed shape when compiling wdl model + return None + # Pad input and output channels to use tensorcore schedule. 
+ if dtype in ["float16", "int8", "uint8"]: # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) if ( (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) @@ -147,30 +154,37 @@ def _dense_legalize(attrs, inputs, arg_types): return None candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] - (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N, candidates) - - if extra_flops_ratio > 2: - logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio) + elif dtype in ["int4", "uint4"]: + if (M % 8 == 0 and K % 32 == 0 and N % 8 == 0): + # no need to pad return None - - logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio) - - if dm or dk: - x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) - else: - x_ = x - if dn or dk: - y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) - else: - y_ = y - out_ = relay.nn.dense(x_, y_) - if dm or dn: - original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape) - else: - out = out_ - return out - return None + candidates = [(8, 32, 8)] + else: + return None + + (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N, candidates) + + if extra_flops_ratio > 2: + logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio) + return None + + logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio) + + if dm or dk: + x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) + else: + x_ = x + if dn or dk: + y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) + else: + y_ = y + out_ = relay.nn.dense(x_, y_) + if dm or dn: + original_out_shape = [x.value for x in output_tensor.shape] + out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape) + else: + out = out_ + return out def pad_to_tensorcore(M, K, N, candidates): diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index afb251417315..b4490f7ef5ba 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -70,3 +70,4 @@ from .space_to_batch_nd import space_to_batch_nd_python from .batch_to_space_nd import batch_to_space_nd_python from .nll_loss import nll_loss +from .dense import dense diff --git a/python/tvm/topi/testing/dense.py b/python/tvm/topi/testing/dense.py new file mode 100644 index 000000000000..c21b2825b2ac --- /dev/null +++ b/python/tvm/topi/testing/dense.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Dense in python""" +import numpy as np + + +def dense(x, y, bias, use_bias=False, use_relu=False, out_dtype=None): + """dense operator implemented in numpy. 
+
+    Parameters
+    ----------
+    x : numpy.ndarray
+        2-D with shape [M, K]
+
+    y : numpy.ndarray
+        2-D with shape [N, K]
+
+    bias : numpy.ndarray
+        1-D with shape [N,]
+
+    out_dtype : string, optional
+        Specify the dtype of output
+
+    Returns
+    -------
+    out : numpy.ndarray
+        2-D with shape [M, N]
+    """
+    dtype = x.dtype if out_dtype is None else out_dtype
+    if use_bias:
+        out = np.dot(x.astype(dtype), y.T.astype(dtype)) + bias
+    else:
+        out = np.dot(x.astype(dtype), y.T.astype(dtype))
+
+    if use_relu:
+        out = np.maximum(out, 0)
+
+    return out
diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py
index 1312b396fe4c..bcd69f7253ef 100644
--- a/tests/python/relay/test_pass_legalize_tensorcore.py
+++ b/tests/python/relay/test_pass_legalize_tensorcore.py
@@ -206,7 +206,7 @@ def expected():
 
 @tvm.testing.uses_gpu
 def test_legalize_dense():
-    def _test_legalize_dense(data_shape, kernel_shape, pad_shape, do_pad=True):
+    def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
         """test legalize dense to enable tensorcore"""
         M, K = data_shape
         N, _ = kernel_shape
@@ -214,8 +214,8 @@ def _test_legalize_dense(data_shape, kernel_shape, pad_shape, do_pad=True):
         dm, dk, dn = pad_shape
 
         def before():
-            x = relay.var("x", shape=data_shape, dtype="float16")
-            weight = relay.var("weight", shape=kernel_shape, dtype="float16")
+            x = relay.var("x", shape=data_shape, dtype=dtype)
+            weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
             y = relay.nn.dense(x, weight)
             y = relay.Function([x, weight], y)
             return y
@@ -227,12 +227,12 @@ def legalize_dense(attrs, inputs, types):
         def expected():
             if not do_pad:
                 return before()
-            x = relay.var("x", shape=data_shape, dtype="float16")
+            x = relay.var("x", shape=data_shape, dtype=dtype)
             if dm or dk:
                 x_pad = relay.nn.pad(x, pad_width=((0, dm), (0, dk)))
             else:
                 x_pad = x
-            weight = relay.var("weight", shape=(kernel_shape), dtype="float16")
+            weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)
             if dn or dk:
                 weight_pad = relay.nn.pad(weight, pad_width=((0, dn), (0, dk)))
             else:
@@ -255,18 +255,28 @@ def expected():
         assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)
 
     # dense
-    _test_legalize_dense((8, 16), (32, 16), (0, 0, 0), False)
-    _test_legalize_dense((7, 16), (32, 16), (1, 0, 0))
-    _test_legalize_dense((8, 15), (32, 15), (0, 1, 0))
-    _test_legalize_dense((8, 16), (31, 16), (0, 0, 1))
-    _test_legalize_dense((7, 15), (31, 15), (1, 1, 1))
-    _test_legalize_dense((3, 16), (32, 16), (5, 0, 0))
-    _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), False)
+    for dtype in ["float16", "int8"]:
+        _test_legalize_dense((8, 16), (32, 16), (0, 0, 0), dtype, False)
+        _test_legalize_dense((7, 16), (32, 16), (1, 0, 0), dtype)
+        _test_legalize_dense((8, 15), (32, 15), (0, 1, 0), dtype)
+        _test_legalize_dense((8, 16), (31, 16), (0, 0, 1), dtype)
+        _test_legalize_dense((7, 15), (31, 15), (1, 1, 1), dtype)
+        _test_legalize_dense((3, 16), (32, 16), (5, 0, 0), dtype)
+        _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), dtype, False)
+
+    _test_legalize_dense((8, 32), (32, 32), (0, 0, 0), "int4", False)
+    _test_legalize_dense((7, 32), (32, 32), (1, 0, 0), "int4")
+    _test_legalize_dense((8, 31), (32, 31), (0, 1, 0), "int4")
+    _test_legalize_dense((8, 32), (31, 32), (0, 0, 1), "int4")
+    _test_legalize_dense((7, 31), (31, 31), (1, 1, 1), "int4")
+    _test_legalize_dense((3, 32), (32, 32), (5, 0, 0), "int4")
+    _test_legalize_dense((8, 16), (32, 16), (0, 16, 0), "int4")
+
_test_legalize_dense((2, 16), (32, 16), (0, 0, 0), "int4", False) @tvm.testing.uses_gpu def test_legalize_batch_matmul(): - def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, do_pad=True): + def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, dtype, do_pad=True): """test legalize dense to enable tensorcore""" B, M, _ = data_shape _, N, _ = kernel_shape @@ -274,8 +284,8 @@ def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, do_pad=True dm, dk, dn = pad_shape def before(): - x = relay.var("x", shape=data_shape, dtype="float16") - weight = relay.var("weight", shape=kernel_shape, dtype="float16") + x = relay.var("x", shape=data_shape, dtype=dtype) + weight = relay.var("weight", shape=kernel_shape, dtype=dtype) y = relay.nn.batch_matmul(x, weight) y = relay.Function([x, weight], y) return y @@ -287,12 +297,12 @@ def legalize_batch_matmul(attrs, inputs, types): def expected(): if not do_pad: return before() - x = relay.var("x", shape=data_shape, dtype="float16") + x = relay.var("x", shape=data_shape, dtype=dtype) if dm or dk: x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) else: x_pad = x - weight = relay.var("weight", shape=(kernel_shape), dtype="float16") + weight = relay.var("weight", shape=(kernel_shape), dtype=dtype) if dn or dk: weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk))) else: @@ -314,13 +324,23 @@ def expected(): b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) - _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 0, 0), False) - _test_legalize_batch_matmul((16, 7, 16), (16, 32, 16), (1, 0, 0)) - _test_legalize_batch_matmul((16, 8, 15), (16, 32, 15), (0, 1, 0)) - _test_legalize_batch_matmul((16, 8, 16), (16, 31, 16), (0, 0, 1)) - _test_legalize_batch_matmul((16, 7, 15), (16, 31, 15), (1, 1, 1)) - _test_legalize_batch_matmul((16, 3, 16), (16, 32, 16), (5, 0, 0)) - _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), False) + for dtype in ["float16", "int8"]: + _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 0, 0), dtype, False) + _test_legalize_batch_matmul((16, 7, 16), (16, 32, 16), (1, 0, 0), dtype) + _test_legalize_batch_matmul((16, 8, 15), (16, 32, 15), (0, 1, 0), dtype) + _test_legalize_batch_matmul((16, 8, 16), (16, 31, 16), (0, 0, 1), dtype) + _test_legalize_batch_matmul((16, 7, 15), (16, 31, 15), (1, 1, 1), dtype) + _test_legalize_batch_matmul((16, 3, 16), (16, 32, 16), (5, 0, 0), dtype) + _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), dtype, False) + + _test_legalize_batch_matmul((16, 8, 32), (16, 32, 32), (0, 0, 0), "int4", False) + _test_legalize_batch_matmul((16, 7, 32), (16, 32, 32), (1, 0, 0), "int4") + _test_legalize_batch_matmul((16, 8, 31), (16, 32, 31), (0, 1, 0), "int4") + _test_legalize_batch_matmul((16, 8, 32), (16, 31, 32), (0, 0, 1), "int4") + _test_legalize_batch_matmul((16, 7, 31), (16, 31, 31), (1, 1, 1), "int4") + _test_legalize_batch_matmul((16, 3, 32), (16, 32, 32), (5, 0, 0), "int4") + _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 16, 0), "int4") + _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), "int4", False) if __name__ == "__main__": diff --git a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py index 31a7e85113ab..5def380ba1b1 100644 --- a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py +++ 
b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py
@@ -29,34 +29,71 @@
     "gpu": (topi.cuda.batch_matmul_tensorcore, topi.cuda.schedule_batch_matmul_tensorcore),
 }
 
+def convert_int32_into_int4(a_int32):
+    """convert a tensor of int32 values into packed int4
+    Parameters
+    ----------
+    a_int32 : numpy.ndarray
 
-def verify_batch_matmul(x_batch, y_batch, M, N, K):
-    x = te.placeholder((x_batch, M, K), name="x")
-    y = te.placeholder((y_batch, N, K), name="y")
-    dtype = x.dtype
+    Returns
+    -------
+    a_int4 : numpy.ndarray
+    """
+    B, K, L = a_int32.shape
+    assert L % 8 == 0
+    a_int4 = np.zeros(shape=(B, K, L // 8), dtype=np.int32)
+    for b in range(B):
+        for k in range(K):
+            for l in range(L // 8):
+                for m in range(min(8, L - l * 8)):
+                    a_int4[b, k, l] = a_int4[b, k, l] | (
+                        (a_int32[b, k, l * 8 + m] & 0xF) << ((7 - m) * 4)
+                    )
+    return a_int4
+
+
+def verify_batch_matmul(x_batch, y_batch, M, N, K, dtype):
+    x = te.placeholder((x_batch, M, K), name="x", dtype=dtype)
+    y = te.placeholder((y_batch, N, K), name="y", dtype=dtype)
+
+    assert dtype in ["int4", "int8", "float16"]
+
+    out_dtype = "float32"
+    if dtype in ["int8", "int4"]:
+        out_dtype = "int32"
 
     # use memoize to pickle the test data for next time use
     @memoize("topi.tests.test_topi_batch_matmul_tensorcore")
     def get_ref_data():
-        a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
-        b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
-        c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
+        if dtype == "int4":
+            a_np = np.random.randint(low=-8, high=7, size=(x_batch, M, K))
+            b_np = np.random.randint(low=-8, high=7, size=(y_batch, N, K))
+        elif dtype == "int8":
+            a_np = np.random.randint(low=-128, high=127, size=(x_batch, M, K)).astype(dtype)
+            b_np = np.random.randint(low=-128, high=127, size=(y_batch, N, K)).astype(dtype)
+        else:
+            a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
+            b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
+        c_np = tvm.topi.testing.batch_matmul(a_np, b_np, out_dtype)
         return (a_np, b_np, c_np)
 
     # get the test data
     a_np, b_np, c_np = get_ref_data()
+    if dtype == "int4":
+        a_np = convert_int32_into_int4(a_np)
+        b_np = convert_int32_into_int4(b_np)
 
     def check_device(device):
         dev = tvm.device(device, 0)
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
             fcompute, fschedule = tvm.topi.testing.dispatch(device, _batch_matmul_implement)
-            out = fcompute(x, y)
+            out = fcompute(x, y, None, out_dtype)
             s = fschedule([out])
             a = tvm.nd.array(a_np, dev)
             b = tvm.nd.array(b_np, dev)
-            c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=dtype), dev)
-            f = tvm.build(s, [x, y, out], device, name="dense")
+            c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out_dtype), dev)
+            f = tvm.build(s, [x, y, out], device, name="batch_matmul")
             f(a, b, c)
             tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
 
@@ -65,10 +102,12 @@ def check_device(device):
 
 @tvm.testing.requires_tensorcore
 def test_batch_matmul():
-    verify_batch_matmul(1, 1, 16, 16, 32)
-    verify_batch_matmul(5, 5, 16, 16, 32)
-    verify_batch_matmul(5, 5, 16, 32, 32)
-    verify_batch_matmul(30, 30, 16, 32, 32)
+    for dtype in ["float16", "int8", "int4"]:
+        print(dtype)
+        verify_batch_matmul(1, 1, 16, 16, 32, dtype)
+        verify_batch_matmul(5, 5, 16, 16, 32, dtype)
+        verify_batch_matmul(5, 5, 16, 32, 32, dtype)
+        verify_batch_matmul(30, 30, 16, 32, 32, dtype)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/topi/python/test_topi_dense_tensorcore.py b/tests/python/topi/python/test_topi_dense_tensorcore.py
index a3657af2c1ca..1b0a7490ce8e 100644
--- a/tests/python/topi/python/test_topi_dense_tensorcore.py
+++ b/tests/python/topi/python/test_topi_dense_tensorcore.py
@@ -29,40 +29,98 @@
 _dense_implement = {"gpu": [(topi.cuda.dense_tensorcore, topi.cuda.schedule_dense_tensorcore)]}
 
 
-def verify_dense(batch, in_dim, out_dim, use_bias=True):
+def convert_int32_into_int4(a_int32):
+    """convert a tensor of int32 values into packed int4
+    Parameters
+    ----------
+    a_int32 : numpy.ndarray
+
+    Returns
+    -------
+    a_int4 : numpy.ndarray
+    """
+    K, L = a_int32.shape
+    assert L % 8 == 0
+    a_int4 = np.zeros(shape=(K, L // 8), dtype=np.int32)
+    for k in range(K):
+        for l in range(L // 8):
+            for m in range(min(8, L - l * 8)):
+                a_int4[k, l] = a_int4[k, l] | (
+                    (a_int32[k, l * 8 + m] & 0xF) << ((7 - m) * 4)
+                )
+    return a_int4
+
+
+def convert_int32_into_int4_bias(a_int32):
+    """convert a 1-D tensor of int32 values into packed int4
+    Parameters
+    ----------
+    a_int32 : numpy.ndarray
+
+    Returns
+    -------
+    a_int4 : numpy.ndarray
+    """
+    L, = a_int32.shape
+    assert L % 8 == 0
+    a_int4 = np.zeros(shape=(L // 8), dtype=np.int32)
+    for l in range(L // 8):
+        for m in range(min(8, L - l * 8)):
+            a_int4[l] = a_int4[l] | (
+                (a_int32[l * 8 + m] & 0xF) << ((7 - m) * 4)
+            )
+    return a_int4
+
+
+def verify_dense(batch, in_dim, out_dim, dtype, use_bias=True):
     """Dense tensorcore verify function"""
-    A = te.placeholder((batch, in_dim), name="A")
-    B = te.placeholder((out_dim, in_dim), name="B")
-    C = te.placeholder((out_dim,), name="C")
-    dtype = A.dtype
+    A = te.placeholder((batch, in_dim), name="A", dtype=dtype)
+    B = te.placeholder((out_dim, in_dim), name="B", dtype=dtype)
+    C = te.placeholder((out_dim,), name="C", dtype=dtype)
+
+    assert dtype in ["int4", "int8", "float16"]
+
+    out_dtype = "float32"
+    if dtype in ["int8", "int4"]:
+        out_dtype = "int32"
 
     # use memoize to pickle the test data for next time use
     @memoize("topi.tests.test_topi_dense_tensorcore")
     def get_ref_data():
-        a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
-        b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
-        c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
-        if use_bias:
-            d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
+        if dtype == "int4":
+            a_np = np.random.randint(low=-8, high=7, size=(batch, in_dim))
+            b_np = np.random.randint(low=-8, high=7, size=(out_dim, in_dim))
+            c_np = np.random.randint(low=-8, high=7, size=(out_dim,))
+        elif dtype == "int8":
+            a_np = np.random.randint(low=-128, high=127, size=(batch, in_dim)).astype(dtype)
+            b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(dtype)
+            c_np = np.random.randint(low=-128, high=127, size=(out_dim,)).astype(dtype)
         else:
-            d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
+            a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
+            b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
+            c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
+        d_np = tvm.topi.testing.dense(a_np, b_np, c_np, use_bias, True, out_dtype)
         return (a_np, b_np, c_np, d_np)
 
     # get the test data
     a_np, b_np, c_np, d_np = get_ref_data()
+    if dtype == "int4":
+        a_np = convert_int32_into_int4(a_np)
+        b_np = convert_int32_into_int4(b_np)
+        c_np = convert_int32_into_int4_bias(c_np)
 
     def check_device(device):
         dev = tvm.device(device, 0)
         print("Running on target: %s" % device)
         for fcompute, fschedule in tvm.topi.testing.dispatch(device, _dense_implement):
             with tvm.target.Target(device):
-                D = fcompute(A, B, C if use_bias else None)
+                D = fcompute(A, B, C if use_bias else None, out_dtype)
                 D = topi.nn.relu(D)
                 s = fschedule([D])
                 a = tvm.nd.array(a_np, dev)
                 b = tvm.nd.array(b_np, dev)
                 c = tvm.nd.array(c_np, dev)
-                d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), dev)
+                d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), dev)
                 f = tvm.build(s, [A, B, C, D], device, name="dense")
                 f(a, b, c, d)
                 tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-3)
@@ -73,11 +131,17 @@ def check_device(device):
 
 @tvm.testing.requires_tensorcore
 def test_dense_tensorcore():
     """Test cases"""
-    verify_dense(8, 16, 32, use_bias=True)
-    verify_dense(16, 32, 16, use_bias=True)
-    verify_dense(256, 1024, 1024, use_bias=True)
-    verify_dense(1000, 1024, 1024, use_bias=False)
-    verify_dense(256, 2048, 1000, use_bias=False)
+    for dtype in ["float16", "int8"]:
+        verify_dense(8, 16, 32, dtype, use_bias=True)
+        verify_dense(16, 32, 16, dtype, use_bias=True)
+        verify_dense(256, 1024, 1024, dtype, use_bias=True)
+        verify_dense(1000, 1024, 1024, dtype, use_bias=False)
+        verify_dense(256, 2048, 1000, dtype, use_bias=False)
+    # TODO: need to fix int4 use_bias=True, wyc-ruiker
+    verify_dense(16, 32, 16, "int4", use_bias=False)
+    verify_dense(256, 1024, 1024, "int4", use_bias=False)
+    verify_dense(1000, 1024, 1024, "int4", use_bias=False)
+    verify_dense(256, 2048, 1000, "int4", use_bias=False)
 
 
 if __name__ == "__main__":

From 500443d3ad58955515ec82963050786964f1152f Mon Sep 17 00:00:00 2001
From: wyc-ruiker
Date: Fri, 2 Jul 2021 20:18:38 +0800
Subject: [PATCH 02/11] fix bug

---
 python/tvm/topi/cuda/tensorcore_alter_op.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py
index f6a7757702ee..4b5c4e99186d 100644
--- a/python/tvm/topi/cuda/tensorcore_alter_op.py
+++ b/python/tvm/topi/cuda/tensorcore_alter_op.py
@@ -95,7 +95,7 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
         y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk)))
     else:
         y_ = y
-    out_ = relay.nn.batch_matmul(x_, y_)
+    out_ = relay.nn.batch_matmul(x_, y_, attrs.out_dtype)
     if dm or dn:
         original_out_shape = [x.value for x in output_tensor.shape]
         out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape)
@@ -122,6 +122,7 @@ def _dense_legalize(attrs, inputs, arg_types):
     result : tvm.relay.Expr
         The legalized expr
     """
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
     # Collect the input tensors.
x_tensor, y_tensor = arg_types[0], arg_types[1] dtype = x_tensor.dtype @@ -178,7 +179,7 @@ def _dense_legalize(attrs, inputs, arg_types): y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) else: y_ = y - out_ = relay.nn.dense(x_, y_) + out_ = relay.nn.dense(x_, y_, **new_attrs) if dm or dn: original_out_shape = [x.value for x in output_tensor.shape] out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape) From 0cb64ffb838076375b08231ec18c0ce82f4e3036 Mon Sep 17 00:00:00 2001 From: wyc-ruiker Date: Fri, 2 Jul 2021 20:25:40 +0800 Subject: [PATCH 03/11] fix --- tests/python/topi/python/test_topi_batch_matmul_tensorcore.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py index 5def380ba1b1..971475d8ec63 100644 --- a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py +++ b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py @@ -103,7 +103,6 @@ def check_device(device): @tvm.testing.requires_tensorcore def test_batch_matmul(): for dtype in ["float16", "int8", "int4"]: - print(dtype) verify_batch_matmul(1, 1, 16, 16, 32, dtype) verify_batch_matmul(5, 5, 16, 16, 32, dtype) verify_batch_matmul(5, 5, 16, 32, 32, dtype) From 96529475d4b96330495d81cf06c7759fe15eb937 Mon Sep 17 00:00:00 2001 From: wyc-ruiker Date: Fri, 2 Jul 2021 17:02:14 +0800 Subject: [PATCH 04/11] add int8/int tensorcore for dense/batch_matmul --- python/tvm/relay/op/strategy/cuda.py | 21 ++- .../tvm/topi/cuda/batch_matmul_tensorcore.py | 75 +++++----- python/tvm/topi/cuda/dense_tensorcore.py | 75 +++++----- python/tvm/topi/cuda/tensorcore_alter_op.py | 140 ++++++++++-------- python/tvm/topi/testing/__init__.py | 1 + python/tvm/topi/testing/dense.py | 53 +++++++ .../relay/test_pass_legalize_tensorcore.py | 68 ++++++--- .../test_topi_batch_matmul_tensorcore.py | 67 +++++++-- .../topi/python/test_topi_dense_tensorcore.py | 100 ++++++++++--- 9 files changed, 405 insertions(+), 195 deletions(-) create mode 100644 python/tvm/topi/testing/dense.py diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index aeeb62af11a9..c9c611e5f631 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -844,13 +844,24 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): x, y = inputs _, M, K = get_const_tuple(x.shape) _, N, K = get_const_tuple(y.shape) - if x.dtype in ["float16", "int8", "uint8"] and ( - (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) - or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) - or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) + if ( + ( + x.dtype in ["float16", "int8", "uint8"] + and ( + (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) + or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) + or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) + ) + ) + or ( + x.dtype in ["int4", "uint4"] + and K % 32 == 0 + and M % 8 == 0 + and N % 8 == 0 + ) ): strategy.add_implementation( - wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore), + wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore, need_out_dtype=True), wrap_topi_schedule(topi.cuda.schedule_batch_matmul_tensorcore), name="batch_matmul_tensorcore.cuda", plevel=20, diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py index 962a8af7853b..67dd6d8c892e 100644 --- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py +++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py @@ -29,10 
+29,10 @@ @autotvm.register_topi_compute("batch_matmul_tensorcore.cuda") -def batch_matmul_tensorcore(cfg, x, y, out_shape=None): +def batch_matmul_tensorcore(cfg, x, y, out_shape=None, out_dtype=None): """batch matmul tensorcore operator on cuda""" # todo: deal with out_shape for broadcast, liuxin.ai - return batch_matmul_tensorcore_cuda(x, y) + return batch_matmul_tensorcore_cuda(x, y, out_dtype) @autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda") @@ -57,10 +57,8 @@ def _schedule(cfg, s, C): A, B = s[C].op.input_tensors batch, m_dim, k_dim = get_const_tuple(A.shape) batch, n_dim, k_dim = get_const_tuple(B.shape) + data_dtype = A.dtype out_dtype = C.dtype - # inline astype fp16 - s[A].compute_inline() - s[B].compute_inline() # Explicit memory access AS = s.cache_read(A, "shared", [C]) @@ -94,15 +92,26 @@ def _schedule(cfg, s, C): cfg.define_knob("vec", [1, 2, 4, 8]) # Ensure that the default parameters are applicable when autotvm is not in use - if m_dim % 32 == 0 and n_dim % 8 == 0: - cfg.define_knob("wmma_m", [32, 16, 8]) - elif m_dim % 16 == 0 and n_dim % 16 == 0: - cfg.define_knob("wmma_m", [16, 8, 32]) - elif m_dim % 8 == 0 and n_dim % 32 == 0: - cfg.define_knob("wmma_m", [8, 16, 32]) + if data_dtype in ["float16", "uint8", "int8"]: + if m_dim % 32 == 0 and n_dim % 8 == 0: + cfg.define_knob("wmma_m", [32, 16, 8]) + elif m_dim % 16 == 0 and n_dim % 16 == 0: + cfg.define_knob("wmma_m", [16, 8, 32]) + elif m_dim % 8 == 0 and n_dim % 32 == 0: + cfg.define_knob("wmma_m", [8, 16, 32]) + wmma_k = 16 + wmma_m = cfg["wmma_m"].val + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 + else: + wmma_m = wmma_n = 8 + wmma_k = 32 warp_size = 32 - wmma_k = 16 block_row_warps = cfg["block_row_warps"].val block_col_warps = cfg["block_col_warps"].val warp_row_tiles = cfg["warp_row_tiles"].val @@ -110,16 +119,8 @@ def _schedule(cfg, s, C): chunk = cfg["chunk"].val offset = cfg["offset"].val offsetCS = cfg["offsetCS"].val - wmma_m = cfg["wmma_m"].val vec = cfg["vec"].val - if wmma_m == 16: - wmma_n = 16 - elif wmma_m == 8: - wmma_n = 32 - elif wmma_m == 32: - wmma_n = 8 - # Define the stride of intrin functions AS_align = chunk * wmma_k + offset BS_align = chunk * wmma_k + offset @@ -211,10 +212,8 @@ def shared_shedule(stage, strides): shared_shedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) - # TODO: add checking here, datatype casting may cause precision loss - in_dtype = "float16" - AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype) - BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype) + AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype) + BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=data_dtype) k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm") CL_compute = te.compute( (wmma_m, wmma_n), @@ -236,7 +235,7 @@ def shared_shedule(stage, strides): "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), - "float16", + data_dtype, ), ) s[BF].tensorize( @@ -248,7 +247,7 @@ def shared_shedule(stage, strides): "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), - "float16", + data_dtype, ), ) s[CF].tensorize( @@ -270,7 +269,7 @@ def _callback(op): return s -def batch_matmul_tensorcore_cuda(x, y): +def batch_matmul_tensorcore_cuda(x, y, out_dtype=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. 
@@ -294,22 +293,26 @@ def batch_matmul_tensorcore_cuda(x, y): assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent" batch, M, K = x.shape N = y.shape[1] - out_dtype = x.dtype - assert ( - (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) - or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) - or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) - ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)" + if out_dtype is None: + out_dtype = x.dtype - x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, k].astype("float16")) - y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, k].astype("float16")) + assert x.dtype == y.dtype + assert x.dtype in ["float16", "uint8", "int8", "uint4", "int4"] + if x.dtype in ["float16", "uint8", "int8"]: + assert ( + (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) + or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) + or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) + ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)" + else: + assert(M % 8 == 0 and K % 32 == 0 and N % 8 == 0), "The shape of (M, K, N) must be multiple of (8, 32, 8)" k = te.reduce_axis((0, K), name="k") return te.compute( (batch, M, N), lambda b, i, j: te.sum( - x_16[b, i, k].astype(out_dtype) * y_16[b, j, k].astype(out_dtype), axis=k + x[b, i, k].astype(out_dtype) * y[b, j, k].astype(out_dtype), axis=k ), tag="batch_matmul_tensorcore", ) diff --git a/python/tvm/topi/cuda/dense_tensorcore.py b/python/tvm/topi/cuda/dense_tensorcore.py index 430f8044528c..d82c522eb3cb 100644 --- a/python/tvm/topi/cuda/dense_tensorcore.py +++ b/python/tvm/topi/cuda/dense_tensorcore.py @@ -60,21 +60,26 @@ def dense_tensorcore_cuda(data, weight, bias=None, out_dtype=None): out_dtype = data.dtype batch, in_dim = get_const_tuple(data.shape) out_dim, _ = get_const_tuple(weight.shape) - assert ( - (batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0) - or (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0) - or (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0) - ), ( - "The shape of (batch, in_dim, out_dim) " - "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now" - ) + + assert data.dtype == weight.dtype + assert data.dtype in ["float16", "int8", "uint8", "int4", "uint4"] + if data.dtype in ["float16", "int8", "uint8"]: + assert ( + (batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0) + or (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0) + or (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0) + ), ( + "The shape of (batch, in_dim, out_dim) " + "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now" + ) + else: + assert(batch % 8 == 0 and in_dim % 32 == 0 and out_dim % 8 == 0), "The shape of (batch, in_dim, out_dim) must be multiple of (8, 32, 8)" + k = te.reduce_axis((0, in_dim), name="k") - data_16 = te.compute((batch, in_dim), lambda b, i: data[b, i].astype("float16")) - weight_16 = te.compute((out_dim, in_dim), lambda o, i: weight[o, i].astype("float16")) matmul = te.compute( (batch, out_dim), lambda i, j: te.sum( - data_16[i, k].astype(out_dtype) * weight_16[j, k].astype(out_dtype), axis=k + data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k ), name="T_dense", tag="dense_tensorcore", @@ -92,9 +97,8 @@ def _schedule_dense_tensorcore(cfg, s, C): """Schedule dense operator using Tensorcore""" A, B = s[C].op.input_tensors batch, out_dim = get_const_tuple(C.shape) + data_dtype = A.dtype out_dtype = C.dtype - s[A].compute_inline() - 
s[B].compute_inline() # Explicit memory access AS = s.cache_read(A, "shared", [C]) @@ -127,16 +131,27 @@ def _schedule_dense_tensorcore(cfg, s, C): cfg.define_knob("offsetCS", [0, 8]) cfg.define_knob("vec", [1, 2, 4, 8]) - # Ensure that the default parameters are applicable when autotvm is not in use - if batch % 32 == 0 and out_dim % 8 == 0: - cfg.define_knob("wmma_m", [32, 16, 8]) - elif batch % 16 == 0 and out_dim % 16 == 0: - cfg.define_knob("wmma_m", [16, 8, 32]) - elif batch % 8 == 0 and out_dim % 32 == 0: - cfg.define_knob("wmma_m", [8, 16, 32]) + if data_dtype in ["float16", "int8", "uint8"]: + # Ensure that the default parameters are applicable when autotvm is not in use + if batch % 32 == 0 and out_dim % 8 == 0: + cfg.define_knob("wmma_m", [32, 16, 8]) + elif batch % 16 == 0 and out_dim % 16 == 0: + cfg.define_knob("wmma_m", [16, 8, 32]) + elif batch % 8 == 0 and out_dim % 32 == 0: + cfg.define_knob("wmma_m", [8, 16, 32]) + wmma_k = 16 + wmma_m = cfg["wmma_m"].val + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 + else: + wmma_m = wmma_n = 8 + wmma_k = 32 warp_size = 32 - wmma_k = 16 block_row_warps = cfg["block_row_warps"].val block_col_warps = cfg["block_col_warps"].val warp_row_tiles = cfg["warp_row_tiles"].val @@ -144,16 +159,8 @@ def _schedule_dense_tensorcore(cfg, s, C): chunk = cfg["chunk"].val offset = cfg["offset"].val offsetCS = cfg["offsetCS"].val - wmma_m = cfg["wmma_m"].val vec = cfg["vec"].val - if wmma_m == 16: - wmma_n = 16 - elif wmma_m == 8: - wmma_n = 32 - elif wmma_m == 32: - wmma_n = 8 - # Define the stride of intrin functions AS_align = chunk * wmma_k + offset BS_align = chunk * wmma_k + offset @@ -245,10 +252,8 @@ def shared_shedule(stage, strides): shared_shedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) - # TODO: add checking here, datatype casting may cause precision loss - in_dtype = "float16" - AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype) - BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype) + AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype) + BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=data_dtype) k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm") CL_compute = te.compute( (wmma_m, wmma_n), @@ -264,13 +269,13 @@ def shared_shedule(stage, strides): s[AF].tensorize( b_ii, intrin_wmma_load_matrix_A( - AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), "float16" + AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), data_dtype ), ) s[BF].tensorize( o_ii, intrin_wmma_load_matrix_W( - BF_stride, BS_stride, shape, "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), "float16" + BF_stride, BS_stride, shape, "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), data_dtype ), ) s[CF].tensorize( diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index eb7c71ddf1c9..f6a7757702ee 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -54,14 +54,14 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): # Collect the input exprs. x, y = inputs - # Pad input and output channels to use tensorcore schedule. 
- if dtype in ["float16"]: # todo: support int8/int4 - B, M, K = x_tensor.shape - B, N, K = y_tensor.shape - M = M.value - K = K.value - N = N.value + B, M, K = x_tensor.shape + B, N, K = y_tensor.shape + M = M.value + K = K.value + N = N.value + # Pad input and output channels to use tensorcore schedule. + if dtype in ["float16", "int8", "uint8"]: # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) if ( (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) @@ -70,31 +70,38 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): ): # no need to pad return None - candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] - (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N, candidates) - - if extra_flops > 2: - logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops) + elif dtype in ["int4", "uint4"]: + if (M % 8 == 0 and K % 32 == 0 and N % 8 == 0): + # no need to pad return None - logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops) - if dm or dk: - x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) - else: - x_ = x - if dn or dk: - y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) - else: - y_ = y - out_ = relay.nn.batch_matmul(x_, y_) - if dm or dn: - original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape) - else: - out = out_ - return out - return None + candidates = [(8, 32, 8)] + else: + return None + + (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N, candidates) + + if extra_flops > 2: + logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops) + return None + + logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops) + if dm or dk: + x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) + else: + x_ = x + if dn or dk: + y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) + else: + y_ = y + out_ = relay.nn.batch_matmul(x_, y_) + if dm or dn: + original_out_shape = [x.value for x in output_tensor.shape] + out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape) + else: + out = out_ + return out @nn.dense_legalize.register("cuda") @@ -125,18 +132,18 @@ def _dense_legalize(attrs, inputs, arg_types): # Collect the input exprs. x, y = inputs - # Pad input and output channels to use tensorcore schedule. - if dtype in ["float16"]: # todo: support int8/int4 - M, K = x_tensor.shape - N, K = y_tensor.shape - try: - M = M.value - K = K.value - N = N.value - except AttributeError: - # todo: deal with unfixed shape when compiling wdl model - return None + M, K = x_tensor.shape + N, K = y_tensor.shape + try: + M = M.value + K = K.value + N = N.value + except AttributeError: + # todo: deal with unfixed shape when compiling wdl model + return None + # Pad input and output channels to use tensorcore schedule. 
+ if dtype in ["float16", "int8", "uint8"]: # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) if ( (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) @@ -147,30 +154,37 @@ def _dense_legalize(attrs, inputs, arg_types): return None candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] - (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N, candidates) - - if extra_flops_ratio > 2: - logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio) + elif dtype in ["int4", "uint4"]: + if (M % 8 == 0 and K % 32 == 0 and N % 8 == 0): + # no need to pad return None - - logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio) - - if dm or dk: - x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) - else: - x_ = x - if dn or dk: - y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) - else: - y_ = y - out_ = relay.nn.dense(x_, y_) - if dm or dn: - original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape) - else: - out = out_ - return out - return None + candidates = [(8, 32, 8)] + else: + return None + + (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N, candidates) + + if extra_flops_ratio > 2: + logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio) + return None + + logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio) + + if dm or dk: + x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) + else: + x_ = x + if dn or dk: + y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) + else: + y_ = y + out_ = relay.nn.dense(x_, y_) + if dm or dn: + original_out_shape = [x.value for x in output_tensor.shape] + out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape) + else: + out = out_ + return out def pad_to_tensorcore(M, K, N, candidates): diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index afb251417315..b4490f7ef5ba 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -70,3 +70,4 @@ from .space_to_batch_nd import space_to_batch_nd_python from .batch_to_space_nd import batch_to_space_nd_python from .nll_loss import nll_loss +from .dense import dense diff --git a/python/tvm/topi/testing/dense.py b/python/tvm/topi/testing/dense.py new file mode 100644 index 000000000000..c21b2825b2ac --- /dev/null +++ b/python/tvm/topi/testing/dense.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Dense in python""" +import numpy as np + + +def dense(x, y, bias, use_bias=False, use_relu=False, out_dtype=None): + """dense operator implemented in numpy. 
+ + Parameters + ---------- + x : numpy.ndarray + 2-D with shape [M, K] + + y : numpy.ndarray + 2-D with shape [N, K] + + bias: numpy.ndarray + 1-D with shape [M,] + + out_dtype: string, optional + Specify the dtype of output + + Returns + ------- + out : numpy.ndarray + 2-D with shape [M, N] + """ + dtype = x.dtype if out_dtype is None else out_dtype + if use_bias: + out = np.dot(x.astype(dtype), y.T.astype(dtype)) + bias + else: + out = np.dot(x.astype(dtype), y.T.astype(dtype)) + + if use_relu: + out = np.maximum(out, 0) + + return out diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py index 1312b396fe4c..bcd69f7253ef 100644 --- a/tests/python/relay/test_pass_legalize_tensorcore.py +++ b/tests/python/relay/test_pass_legalize_tensorcore.py @@ -206,7 +206,7 @@ def expected(): @tvm.testing.uses_gpu def test_legalize_dense(): - def _test_legalize_dense(data_shape, kernel_shape, pad_shape, do_pad=True): + def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True): """test legalize dense to enable tensorcore""" M, K = data_shape N, _ = kernel_shape @@ -214,8 +214,8 @@ def _test_legalize_dense(data_shape, kernel_shape, pad_shape, do_pad=True): dm, dk, dn = pad_shape def before(): - x = relay.var("x", shape=data_shape, dtype="float16") - weight = relay.var("weight", shape=kernel_shape, dtype="float16") + x = relay.var("x", shape=data_shape, dtype=dtype) + weight = relay.var("weight", shape=kernel_shape, dtype=dtype) y = relay.nn.dense(x, weight) y = relay.Function([x, weight], y) return y @@ -227,12 +227,12 @@ def legalize_dense(attrs, inputs, types): def expected(): if not do_pad: return before() - x = relay.var("x", shape=data_shape, dtype="float16") + x = relay.var("x", shape=data_shape, dtype=dtype) if dm or dk: x_pad = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) else: x_pad = x - weight = relay.var("weight", shape=(kernel_shape), dtype="float16") + weight = relay.var("weight", shape=(kernel_shape), dtype=dtype) if dn or dk: weight_pad = relay.nn.pad(weight, pad_width=((0, dn), (0, dk))) else: @@ -255,18 +255,28 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) # dense - _test_legalize_dense((8, 16), (32, 16), (0, 0, 0), False) - _test_legalize_dense((7, 16), (32, 16), (1, 0, 0)) - _test_legalize_dense((8, 15), (32, 15), (0, 1, 0)) - _test_legalize_dense((8, 16), (31, 16), (0, 0, 1)) - _test_legalize_dense((7, 15), (31, 15), (1, 1, 1)) - _test_legalize_dense((3, 16), (32, 16), (5, 0, 0)) - _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), False) + for dtype in ["float16", "int8"]: + _test_legalize_dense((8, 16), (32, 16), (0, 0, 0), dtype, False) + _test_legalize_dense((7, 16), (32, 16), (1, 0, 0), dtype) + _test_legalize_dense((8, 15), (32, 15), (0, 1, 0), dtype) + _test_legalize_dense((8, 16), (31, 16), (0, 0, 1), dtype) + _test_legalize_dense((7, 15), (31, 15), (1, 1, 1), dtype) + _test_legalize_dense((3, 16), (32, 16), (5, 0, 0), dtype) + _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), dtype, False) + + _test_legalize_dense((8, 32), (32, 32), (0, 0, 0), "int4", False) + _test_legalize_dense((7, 32), (32, 32), (1, 0, 0), "int4") + _test_legalize_dense((8, 31), (32, 31), (0, 1, 0), "int4") + _test_legalize_dense((8, 32), (31, 32), (0, 0, 1), "int4") + _test_legalize_dense((7, 31), (31, 31), (1, 1, 1), "int4") + _test_legalize_dense((3, 32), (32, 32), (5, 0, 0), "int4") + _test_legalize_dense((8, 16), (32, 16), (0, 16, 0), "int4") + 
diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py
index 1312b396fe4c..bcd69f7253ef 100644
--- a/tests/python/relay/test_pass_legalize_tensorcore.py
+++ b/tests/python/relay/test_pass_legalize_tensorcore.py
@@ -206,7 +206,7 @@
 @tvm.testing.uses_gpu
 def test_legalize_dense():
-    def _test_legalize_dense(data_shape, kernel_shape, pad_shape, do_pad=True):
+    def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
         """test legalize dense to enable tensorcore"""
         M, K = data_shape
         N, _ = kernel_shape
@@ -214,8 +214,8 @@ def _test_legalize_dense(data_shape, kernel_shape, pad_shape, do_pad=True):
         dm, dk, dn = pad_shape

         def before():
-            x = relay.var("x", shape=data_shape, dtype="float16")
-            weight = relay.var("weight", shape=kernel_shape, dtype="float16")
+            x = relay.var("x", shape=data_shape, dtype=dtype)
+            weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
             y = relay.nn.dense(x, weight)
             y = relay.Function([x, weight], y)
             return y
@@ -227,12 +227,12 @@ def legalize_dense(attrs, inputs, types):
         def expected():
             if not do_pad:
                 return before()
-            x = relay.var("x", shape=data_shape, dtype="float16")
+            x = relay.var("x", shape=data_shape, dtype=dtype)
             if dm or dk:
                 x_pad = relay.nn.pad(x, pad_width=((0, dm), (0, dk)))
             else:
                 x_pad = x
-            weight = relay.var("weight", shape=(kernel_shape), dtype="float16")
+            weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)
             if dn or dk:
                 weight_pad = relay.nn.pad(weight, pad_width=((0, dn), (0, dk)))
             else:
@@ -255,18 +255,28 @@ def expected():
         assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)

     # dense
-    _test_legalize_dense((8, 16), (32, 16), (0, 0, 0), False)
-    _test_legalize_dense((7, 16), (32, 16), (1, 0, 0))
-    _test_legalize_dense((8, 15), (32, 15), (0, 1, 0))
-    _test_legalize_dense((8, 16), (31, 16), (0, 0, 1))
-    _test_legalize_dense((7, 15), (31, 15), (1, 1, 1))
-    _test_legalize_dense((3, 16), (32, 16), (5, 0, 0))
-    _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), False)
+    for dtype in ["float16", "int8"]:
+        _test_legalize_dense((8, 16), (32, 16), (0, 0, 0), dtype, False)
+        _test_legalize_dense((7, 16), (32, 16), (1, 0, 0), dtype)
+        _test_legalize_dense((8, 15), (32, 15), (0, 1, 0), dtype)
+        _test_legalize_dense((8, 16), (31, 16), (0, 0, 1), dtype)
+        _test_legalize_dense((7, 15), (31, 15), (1, 1, 1), dtype)
+        _test_legalize_dense((3, 16), (32, 16), (5, 0, 0), dtype)
+        _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), dtype, False)
+
+    _test_legalize_dense((8, 32), (32, 32), (0, 0, 0), "int4", False)
+    _test_legalize_dense((7, 32), (32, 32), (1, 0, 0), "int4")
+    _test_legalize_dense((8, 31), (32, 31), (0, 1, 0), "int4")
+    _test_legalize_dense((8, 32), (31, 32), (0, 0, 1), "int4")
+    _test_legalize_dense((7, 31), (31, 31), (1, 1, 1), "int4")
+    _test_legalize_dense((3, 32), (32, 32), (5, 0, 0), "int4")
+    _test_legalize_dense((8, 16), (32, 16), (0, 16, 0), "int4")
+    _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), "int4", False)


 @tvm.testing.uses_gpu
 def test_legalize_batch_matmul():
-    def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, do_pad=True):
+    def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
         """test legalize batch_matmul to enable tensorcore"""
         B, M, _ = data_shape
         _, N, _ = kernel_shape
@@ -274,8 +284,8 @@ def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, do_pad=True
         dm, dk, dn = pad_shape

         def before():
-            x = relay.var("x", shape=data_shape, dtype="float16")
-            weight = relay.var("weight", shape=kernel_shape, dtype="float16")
+            x = relay.var("x", shape=data_shape, dtype=dtype)
+            weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
             y = relay.nn.batch_matmul(x, weight)
             y = relay.Function([x, weight], y)
             return y
@@ -287,12 +297,12 @@ def legalize_batch_matmul(attrs, inputs, types):
         def expected():
             if not do_pad:
                 return before()
-            x = relay.var("x", shape=data_shape, dtype="float16")
+            x = relay.var("x", shape=data_shape, dtype=dtype)
             if dm or dk:
                 x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
             else:
                 x_pad = x
-            weight = relay.var("weight", shape=(kernel_shape), dtype="float16")
+            weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)
             if dn or dk:
                 weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk)))
             else:
@@ -314,13 +324,23 @@ def expected():
         b = run_opt_pass(expected(), transform.InferType())
         assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)

-    _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 0, 0), False)
-    _test_legalize_batch_matmul((16, 7, 16), (16, 32, 16), (1, 0, 0))
-    _test_legalize_batch_matmul((16, 8, 15), (16, 32, 15), (0, 1, 0))
-    _test_legalize_batch_matmul((16, 8, 16), (16, 31, 16), (0, 0, 1))
-    _test_legalize_batch_matmul((16, 7, 15), (16, 31, 15), (1, 1, 1))
-    _test_legalize_batch_matmul((16, 3, 16), (16, 32, 16), (5, 0, 0))
-    _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), False)
+    for dtype in ["float16", "int8"]:
+        _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 0, 0), dtype, False)
+        _test_legalize_batch_matmul((16, 7, 16), (16, 32, 16), (1, 0, 0), dtype)
+        _test_legalize_batch_matmul((16, 8, 15), (16, 32, 15), (0, 1, 0), dtype)
+        _test_legalize_batch_matmul((16, 8, 16), (16, 31, 16), (0, 0, 1), dtype)
+        _test_legalize_batch_matmul((16, 7, 15), (16, 31, 15), (1, 1, 1), dtype)
+        _test_legalize_batch_matmul((16, 3, 16), (16, 32, 16), (5, 0, 0), dtype)
+        _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), dtype, False)
+
+    _test_legalize_batch_matmul((16, 8, 32), (16, 32, 32), (0, 0, 0), "int4", False)
+    _test_legalize_batch_matmul((16, 7, 32), (16, 32, 32), (1, 0, 0), "int4")
+    _test_legalize_batch_matmul((16, 8, 31), (16, 32, 31), (0, 1, 0), "int4")
+    _test_legalize_batch_matmul((16, 8, 32), (16, 31, 32), (0, 0, 1), "int4")
+    _test_legalize_batch_matmul((16, 7, 31), (16, 31, 31), (1, 1, 1), "int4")
+    _test_legalize_batch_matmul((16, 3, 32), (16, 32, 32), (5, 0, 0), "int4")
+    _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 16, 0), "int4")
+    _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), "int4", False)


 if __name__ == "__main__":
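# To make the padding expectations in the tests above easier to audit, here is a
# standalone sketch of the arithmetic (not the TVM implementation; pad_amounts is
# a hypothetical helper): each dimension is rounded up to the candidate tile
# multiple, and int4 has the single candidate (8, 32, 8), so the round-up is exact.

def pad_amounts(M, K, N, candidate=(8, 32, 8)):
    cm, ck, cn = candidate
    return ((-M) % cm, (-K) % ck, (-N) % cn)  # (dm, dk, dn)

assert pad_amounts(7, 31, 31) == (1, 1, 1)   # the (7, 31) x (31, 31) int4 case
assert pad_amounts(3, 32, 32) == (5, 0, 0)
assert pad_amounts(8, 16, 32) == (0, 16, 0)  # K=16 must be padded to 32 for int4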
diff --git a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py
index 31a7e85113ab..5def380ba1b1 100644
--- a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py
+++ b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py
@@ -29,34 +29,71 @@
     "gpu": (topi.cuda.batch_matmul_tensorcore, topi.cuda.schedule_batch_matmul_tensorcore),
 }

+def convert_int32_into_int4(a_int32):
+    """Pack 8 int4 values into each int32 of the output, MSB first.
+    Parameters
+    ----------
+    a_int32 : numpy.ndarray, 3-D int32, last dimension a multiple of 8

-def verify_batch_matmul(x_batch, y_batch, M, N, K):
-    x = te.placeholder((x_batch, M, K), name="x")
-    y = te.placeholder((y_batch, N, K), name="y")
-    dtype = x.dtype
+    Returns
+    -------
+    a_int4 : numpy.ndarray, 3-D int32, last dimension packed 8x
+    """
+    B, K, L = a_int32.shape
+    assert L % 8 == 0
+    a_int4 = np.zeros(shape=(B, K, L // 8), dtype=np.int32)
+    for b in range(B):
+        for k in range(K):
+            for l in range(L // 8):
+                for m in range(min(8, L - l * 8)):
+                    a_int4[b, k, l] = a_int4[b, k, l] | (
+                        (a_int32[b, k, l * 8 + m] & 0xF) << ((7 - m) * 4)
+                    )
+    return a_int4
+
+
+def verify_batch_matmul(x_batch, y_batch, M, N, K, dtype):
+    x = te.placeholder((x_batch, M, K), name="x", dtype=dtype)
+    y = te.placeholder((y_batch, N, K), name="y", dtype=dtype)
+
+    assert dtype in ["int4", "int8", "float16"]
+
+    out_dtype = "float32"
+    if dtype in ["int8", "int4"]:
+        out_dtype = "int32"

     # use memoize to pickle the test data for next time use
     @memoize("topi.tests.test_topi_batch_matmul_tensorcore")
     def get_ref_data():
-        a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
-        b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
-        c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
+        if dtype == "int4":
+            a_np = np.random.randint(low=-8, high=7, size=(x_batch, M, K))
+            b_np = np.random.randint(low=-8, high=7, size=(y_batch, N, K))
+        elif dtype == "int8":
+            a_np = np.random.randint(low=-128, high=127, size=(x_batch, M, K)).astype(dtype)
+            b_np = np.random.randint(low=-128, high=127, size=(y_batch, N, K)).astype(dtype)
+        else:
+            a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
+            b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
+        c_np = tvm.topi.testing.batch_matmul(a_np, b_np, out_dtype)
         return (a_np, b_np, c_np)

     # get the test data
     a_np, b_np, c_np = get_ref_data()
+    if dtype == "int4":
+        a_np = convert_int32_into_int4(a_np)
+        b_np = convert_int32_into_int4(b_np)

     def check_device(device):
         dev = tvm.device(device, 0)
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
             fcompute, fschedule = tvm.topi.testing.dispatch(device, _batch_matmul_implement)
-            out = fcompute(x, y)
+            out = fcompute(x, y, None, out_dtype)
             s = fschedule([out])
             a = tvm.nd.array(a_np, dev)
             b = tvm.nd.array(b_np, dev)
-            c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=dtype), dev)
-            f = tvm.build(s, [x, y, out], device, name="dense")
+            c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out_dtype), dev)
+            f = tvm.build(s, [x, y, out], device, name="batch_matmul")
             f(a, b, c)
             tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)

@@ -65,10 +102,12 @@ def check_device(device):

 @tvm.testing.requires_tensorcore
 def test_batch_matmul():
-    verify_batch_matmul(1, 1, 16, 16, 32)
-    verify_batch_matmul(5, 5, 16, 16, 32)
-    verify_batch_matmul(5, 5, 16, 32, 32)
-    verify_batch_matmul(30, 30, 16, 32, 32)
+    for dtype in ["float16", "int8", "int4"]:
+        print(dtype)
+        verify_batch_matmul(1, 1, 16, 16, 32, dtype)
+        verify_batch_matmul(5, 5, 16, 16, 32, dtype)
+        verify_batch_matmul(5, 5, 16, 32, 32, dtype)
+        verify_batch_matmul(30, 30, 16, 32, 32, dtype)


 if __name__ == "__main__":
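# A small round-trip sketch (illustrative only; unpack_int4 is a hypothetical
# helper, and convert_int32_into_int4 is assumed in scope from the hunk above)
# showing the nibble layout: eight int4 values per int32 word, most significant
# nibble first.

import numpy as np

def unpack_int4(word, idx):
    """Recover the idx-th int4 value (idx 0 = most significant nibble)."""
    nibble = (word >> ((7 - idx) * 4)) & 0xF
    return nibble - 16 if nibble >= 8 else nibble  # sign-extend the 4-bit value

vals = np.arange(8, dtype=np.int32).reshape(1, 1, 8)  # values 0..7, shape (B, K, L)
packed = convert_int32_into_int4(vals)                # shape (1, 1, 1)
assert [unpack_int4(int(packed[0, 0, 0]), i) for i in range(8)] == list(range(8))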
diff --git a/tests/python/topi/python/test_topi_dense_tensorcore.py b/tests/python/topi/python/test_topi_dense_tensorcore.py
index a3657af2c1ca..1b0a7490ce8e 100644
--- a/tests/python/topi/python/test_topi_dense_tensorcore.py
+++ b/tests/python/topi/python/test_topi_dense_tensorcore.py
@@ -29,40 +29,98 @@
 _dense_implement = {"gpu": [(topi.cuda.dense_tensorcore, topi.cuda.schedule_dense_tensorcore)]}

-def verify_dense(batch, in_dim, out_dim, use_bias=True):
+def convert_int32_into_int4(a_int32):
+    """Pack 8 int4 values into each int32 of the output, MSB first.
+    Parameters
+    ----------
+    a_int32 : numpy.ndarray, 2-D int32, last dimension a multiple of 8
+
+    Returns
+    -------
+    a_int4 : numpy.ndarray, 2-D int32, last dimension packed 8x
+    """
+    K, L = a_int32.shape
+    assert L % 8 == 0
+    a_int4 = np.zeros(shape=(K, L // 8), dtype=np.int32)
+    for k in range(K):
+        for l in range(L // 8):
+            for m in range(min(8, L - l * 8)):
+                a_int4[k, l] = a_int4[k, l] | (
+                    (a_int32[k, l * 8 + m] & 0xF) << ((7 - m) * 4)
+                )
+    return a_int4
+
+
+def convert_int32_into_int4_bias(a_int32):
+    """Pack 8 int4 bias values into each int32 of the output, MSB first.
+    Parameters
+    ----------
+    a_int32 : numpy.ndarray, 1-D int32, length a multiple of 8
+
+    Returns
+    -------
+    a_int4 : numpy.ndarray, 1-D int32, packed 8x
+    """
+    L, = a_int32.shape
+    assert L % 8 == 0
+    a_int4 = np.zeros(shape=(L // 8), dtype=np.int32)
+    for l in range(L // 8):
+        for m in range(min(8, L - l * 8)):
+            a_int4[l] = a_int4[l] | (
+                (a_int32[l * 8 + m] & 0xF) << ((7 - m) * 4)
+            )
+    return a_int4
+
+
+def verify_dense(batch, in_dim, out_dim, dtype, use_bias=True):
     """Dense tensorcore verify function"""
-    A = te.placeholder((batch, in_dim), name="A")
-    B = te.placeholder((out_dim, in_dim), name="B")
-    C = te.placeholder((out_dim,), name="C")
-    dtype = A.dtype
+    A = te.placeholder((batch, in_dim), name="A", dtype=dtype)
+    B = te.placeholder((out_dim, in_dim), name="B", dtype=dtype)
+    C = te.placeholder((out_dim,), name="C", dtype=dtype)
+
+    assert dtype in ["int4", "int8", "float16"]
+
+    out_dtype = "float32"
+    if dtype in ["int8", "int4"]:
+        out_dtype = "int32"

     # use memoize to pickle the test data for next time use
     @memoize("topi.tests.test_topi_dense_tensorcore")
     def get_ref_data():
-        a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
-        b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
-        c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
-        if use_bias:
-            d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
+        if dtype == "int4":
+            a_np = np.random.randint(low=-8, high=7, size=(batch, in_dim))
+            b_np = np.random.randint(low=-8, high=7, size=(out_dim, in_dim))
+            c_np = np.random.randint(low=-8, high=7, size=(out_dim,))
+        elif dtype == "int8":
+            a_np = np.random.randint(low=-128, high=127, size=(batch, in_dim)).astype(dtype)
+            b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(dtype)
+            c_np = np.random.randint(low=-128, high=127, size=(out_dim,)).astype(dtype)
         else:
-            d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
+            a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
+            b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
+            c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
+        d_np = tvm.topi.testing.dense(a_np, b_np, c_np, use_bias, True, out_dtype)
         return (a_np, b_np, c_np, d_np)

     # get the test data
     a_np, b_np, c_np, d_np = get_ref_data()
+    if dtype == "int4":
+        a_np = convert_int32_into_int4(a_np)
+        b_np = convert_int32_into_int4(b_np)
+        c_np = convert_int32_into_int4_bias(c_np)

     def check_device(device):
         dev = tvm.device(device, 0)
         print("Running on target: %s" % device)
         for fcompute, fschedule in tvm.topi.testing.dispatch(device, _dense_implement):
             with tvm.target.Target(device):
-                D = fcompute(A, B, C if use_bias else None)
+                D = fcompute(A, B, C if use_bias else None, out_dtype)
                 D = topi.nn.relu(D)
                 s = fschedule([D])
                 a = tvm.nd.array(a_np, dev)
                 b = tvm.nd.array(b_np, dev)
                 c = 
tvm.nd.array(c_np, dev) - d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), dev) + d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), dev) f = tvm.build(s, [A, B, C, D], device, name="dense") f(a, b, c, d) tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-3) @@ -73,11 +131,17 @@ def check_device(device): @tvm.testing.requires_tensorcore def test_dense_tensorcore(): """Test cases""" - verify_dense(8, 16, 32, use_bias=True) - verify_dense(16, 32, 16, use_bias=True) - verify_dense(256, 1024, 1024, use_bias=True) - verify_dense(1000, 1024, 1024, use_bias=False) - verify_dense(256, 2048, 1000, use_bias=False) + for dtype in ["float16", "int8"]: + verify_dense(8, 16, 32, "float16", use_bias=True) + verify_dense(16, 32, 16, dtype, use_bias=True) + verify_dense(256, 1024, 1024, dtype, use_bias=True) + verify_dense(1000, 1024, 1024, dtype, use_bias=False) + verify_dense(256, 2048, 1000, dtype, use_bias=False) + #TODO: need fix int4 use_bias=True, wyc-ruiker + verify_dense(16, 32, 16, "int4", use_bias=False) + verify_dense(256, 1024, 1024, "int4", use_bias=False) + verify_dense(1000, 1024, 1024, "int4", use_bias=False) + verify_dense(256, 2048, 1000, "int4", use_bias=False) if __name__ == "__main__": From 9d906521a2ed9f1b5d42a05237e1e64a43f8d3d2 Mon Sep 17 00:00:00 2001 From: wyc-ruiker Date: Fri, 2 Jul 2021 20:18:38 +0800 Subject: [PATCH 05/11] fix bug --- python/tvm/topi/cuda/tensorcore_alter_op.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index f6a7757702ee..4b5c4e99186d 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -95,7 +95,7 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) else: y_ = y - out_ = relay.nn.batch_matmul(x_, y_) + out_ = relay.nn.batch_matmul(x_, y_, attrs.out_dtype) if dm or dn: original_out_shape = [x.value for x in output_tensor.shape] out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape) @@ -122,6 +122,7 @@ def _dense_legalize(attrs, inputs, arg_types): result : tvm.relay.Expr The legalized expr """ + new_attrs = {k: attrs[k] for k in attrs.keys()} # Collect the input tensors. 
x_tensor, y_tensor = arg_types[0], arg_types[1] dtype = x_tensor.dtype @@ -178,7 +179,7 @@ def _dense_legalize(attrs, inputs, arg_types): y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) else: y_ = y - out_ = relay.nn.dense(x_, y_) + out_ = relay.nn.dense(x_, y_, **new_attrs) if dm or dn: original_out_shape = [x.value for x in output_tensor.shape] out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape) From 184199f699f34b3e72262c9866cf71053be6ca56 Mon Sep 17 00:00:00 2001 From: wyc-ruiker Date: Fri, 2 Jul 2021 20:25:40 +0800 Subject: [PATCH 06/11] fix --- tests/python/topi/python/test_topi_batch_matmul_tensorcore.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py index 5def380ba1b1..971475d8ec63 100644 --- a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py +++ b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py @@ -103,7 +103,6 @@ def check_device(device): @tvm.testing.requires_tensorcore def test_batch_matmul(): for dtype in ["float16", "int8", "int4"]: - print(dtype) verify_batch_matmul(1, 1, 16, 16, 32, dtype) verify_batch_matmul(5, 5, 16, 16, 32, dtype) verify_batch_matmul(5, 5, 16, 32, 32, dtype) From 0c013af745fa090727a1c95945c11fcd01a42f69 Mon Sep 17 00:00:00 2001 From: wyc-ruiker Date: Mon, 5 Jul 2021 14:49:03 +0800 Subject: [PATCH 07/11] fix lint --- python/tvm/relay/op/strategy/cuda.py | 20 ++++++------------- .../tvm/topi/cuda/batch_matmul_tensorcore.py | 8 ++++---- python/tvm/topi/cuda/dense_tensorcore.py | 8 ++++---- python/tvm/topi/cuda/tensorcore_alter_op.py | 4 ++-- python/tvm/topi/testing/dense.py | 2 +- .../test_topi_batch_matmul_tensorcore.py | 3 ++- .../topi/python/test_topi_dense_tensorcore.py | 12 ++++------- 7 files changed, 23 insertions(+), 34 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index c9c611e5f631..1f999a810164 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -845,21 +845,13 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): _, M, K = get_const_tuple(x.shape) _, N, K = get_const_tuple(y.shape) if ( - ( - x.dtype in ["float16", "int8", "uint8"] - and ( - (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) - or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) - or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) - ) + x.dtype in ["float16", "int8", "uint8"] + and ( + (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) + or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) + or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) ) - or ( - x.dtype in ["int4", "uint4"] - and K % 32 == 0 - and M % 8 == 0 - and N % 8 == 0 - ) - ): + ) or (x.dtype in ["int4", "uint4"] and K % 32 == 0 and M % 8 == 0 and N % 8 == 0): strategy.add_implementation( wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore, need_out_dtype=True), wrap_topi_schedule(topi.cuda.schedule_batch_matmul_tensorcore), diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py index 67dd6d8c892e..90bad8fb433d 100644 --- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py +++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py @@ -306,13 +306,13 @@ def batch_matmul_tensorcore_cuda(x, y, out_dtype=None): or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)" else: - assert(M % 8 == 0 and K % 32 == 0 and N % 8 == 
0), "The shape of (M, K, N) must be multiple of (8, 32, 8)" + assert( + M % 8 == 0 and K % 32 == 0 and N % 8 == 0 + ), "The shape of (M, K, N) must be multiple of (8, 32, 8)" k = te.reduce_axis((0, K), name="k") return te.compute( (batch, M, N), - lambda b, i, j: te.sum( - x[b, i, k].astype(out_dtype) * y[b, j, k].astype(out_dtype), axis=k - ), + lambda b, i, j: te.sum(x[b, i, k].astype(out_dtype) * y[b, j, k].astype(out_dtype), axis=k), tag="batch_matmul_tensorcore", ) diff --git a/python/tvm/topi/cuda/dense_tensorcore.py b/python/tvm/topi/cuda/dense_tensorcore.py index d82c522eb3cb..12d270874e6c 100644 --- a/python/tvm/topi/cuda/dense_tensorcore.py +++ b/python/tvm/topi/cuda/dense_tensorcore.py @@ -73,14 +73,14 @@ def dense_tensorcore_cuda(data, weight, bias=None, out_dtype=None): "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now" ) else: - assert(batch % 8 == 0 and in_dim % 32 == 0 and out_dim % 8 == 0), "The shape of (batch, in_dim, out_dim) must be multiple of (8, 32, 8)" + assert( + batch % 8 == 0 and in_dim % 32 == 0 and out_dim % 8 == 0 + ), "The shape of (batch, in_dim, out_dim) must be multiple of (8, 32, 8)" k = te.reduce_axis((0, in_dim), name="k") matmul = te.compute( (batch, out_dim), - lambda i, j: te.sum( - data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k - ), + lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k), name="T_dense", tag="dense_tensorcore", ) diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index 4b5c4e99186d..7eb6dbd5ed51 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -72,7 +72,7 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): return None candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] elif dtype in ["int4", "uint4"]: - if (M % 8 == 0 and K % 32 == 0 and N % 8 == 0): + if M % 8 == 0 and K % 32 == 0 and N % 8 == 0: # no need to pad return None @@ -156,7 +156,7 @@ def _dense_legalize(attrs, inputs, arg_types): candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] elif dtype in ["int4", "uint4"]: - if (M % 8 == 0 and K % 32 == 0 and N % 8 == 0): + if M % 8 == 0 and K % 32 == 0 and N % 8 == 0: # no need to pad return None candidates = [(8, 32, 8)] diff --git a/python/tvm/topi/testing/dense.py b/python/tvm/topi/testing/dense.py index c21b2825b2ac..7871cd71892a 100644 --- a/python/tvm/topi/testing/dense.py +++ b/python/tvm/topi/testing/dense.py @@ -29,7 +29,7 @@ def dense(x, y, bias, use_bias=False, use_relu=False, out_dtype=None): y : numpy.ndarray 2-D with shape [N, K] - + bias: numpy.ndarray 1-D with shape [M,] diff --git a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py index 971475d8ec63..eb657a329889 100644 --- a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py +++ b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py @@ -29,6 +29,7 @@ "gpu": (topi.cuda.batch_matmul_tensorcore, topi.cuda.schedule_batch_matmul_tensorcore), } + def convert_int32_into_int4(a_int32): """convert int32 values into int4 Parameters @@ -70,7 +71,7 @@ def get_ref_data(): b_np = np.random.randint(low=-8, high=7, size=(y_batch, N, K)) elif dtype == "int8": a_np = np.random.randint(low=-128, high=127, size=(x_batch, M, K)).astype(dtype) - b_np = np.random.randint(low=-128, high=127, size=(y_batch, N, K)).astype(dtype) + b_np = np.random.randint(low=-128, high=127, size=(y_batch, 
N, K)).astype(dtype) else: a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype) b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype) diff --git a/tests/python/topi/python/test_topi_dense_tensorcore.py b/tests/python/topi/python/test_topi_dense_tensorcore.py index 1b0a7490ce8e..7e7d3f2209d3 100644 --- a/tests/python/topi/python/test_topi_dense_tensorcore.py +++ b/tests/python/topi/python/test_topi_dense_tensorcore.py @@ -45,9 +45,7 @@ def convert_int32_into_int4(a_int32): for k in range(K): for l in range(L // 8): for m in range(min(8, L - l * 8)): - a_int4[k, l] = a_int4[k, l] | ( - (a_int32[k, l * 8 + m] & 0xF) << ((7 - m) * 4) - ) + a_int4[k, l] = a_int4[k, l] | ((a_int32[k, l * 8 + m] & 0xF) << ((7 - m) * 4)) return a_int4 @@ -61,14 +59,12 @@ def convert_int32_into_int4_bias(a_int32): ------ a_int4 : int """ - L, = a_int32.shape + (L,) = a_int32.shape assert L % 8 == 0 a_int4 = np.zeros(shape=(L // 8), dtype=np.int32) for l in range(L // 8): for m in range(min(8, L - l * 8)): - a_int4[l] = a_int4[l] | ( - (a_int32[l * 8 + m] & 0xF) << ((7 - m) * 4) - ) + a_int4[l] = a_int4[l] | ((a_int32[l * 8 + m] & 0xF) << ((7 - m) * 4)) return a_int4 @@ -137,7 +133,7 @@ def test_dense_tensorcore(): verify_dense(256, 1024, 1024, dtype, use_bias=True) verify_dense(1000, 1024, 1024, dtype, use_bias=False) verify_dense(256, 2048, 1000, dtype, use_bias=False) - #TODO: need fix int4 use_bias=True, wyc-ruiker + # TODO: need fix int4 use_bias=True, wyc-ruiker verify_dense(16, 32, 16, "int4", use_bias=False) verify_dense(256, 1024, 1024, "int4", use_bias=False) verify_dense(1000, 1024, 1024, "int4", use_bias=False) From 94f3f0a4b9155ec4d398efbcb2e129b0ab9c00f7 Mon Sep 17 00:00:00 2001 From: wyc-ruiker Date: Mon, 5 Jul 2021 14:52:32 +0800 Subject: [PATCH 08/11] fix lint --- python/tvm/topi/cuda/batch_matmul_tensorcore.py | 2 +- python/tvm/topi/cuda/dense_tensorcore.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py index 90bad8fb433d..8b911d43c0e9 100644 --- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py +++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py @@ -306,7 +306,7 @@ def batch_matmul_tensorcore_cuda(x, y, out_dtype=None): or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)" else: - assert( + assert ( M % 8 == 0 and K % 32 == 0 and N % 8 == 0 ), "The shape of (M, K, N) must be multiple of (8, 32, 8)" diff --git a/python/tvm/topi/cuda/dense_tensorcore.py b/python/tvm/topi/cuda/dense_tensorcore.py index 12d270874e6c..9b2b3d85b77f 100644 --- a/python/tvm/topi/cuda/dense_tensorcore.py +++ b/python/tvm/topi/cuda/dense_tensorcore.py @@ -73,7 +73,7 @@ def dense_tensorcore_cuda(data, weight, bias=None, out_dtype=None): "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now" ) else: - assert( + assert ( batch % 8 == 0 and in_dim % 32 == 0 and out_dim % 8 == 0 ), "The shape of (batch, in_dim, out_dim) must be multiple of (8, 32, 8)" From 74365f6cfec435a73675ec9cad98f824349c598c Mon Sep 17 00:00:00 2001 From: Wang Yucheng Date: Wed, 7 Jul 2021 19:59:42 +0800 Subject: [PATCH 09/11] Apply suggestions from code review Co-authored-by: Chenfan --- python/tvm/topi/cuda/tensorcore_alter_op.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index 
7eb6dbd5ed51..4392485b79b9 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -87,20 +87,10 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): return None logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops) - if dm or dk: - x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) - else: - x_ = x - if dn or dk: - y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) - else: - y_ = y + x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) if dm or dk else x + y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) if dn or dk else y out_ = relay.nn.batch_matmul(x_, y_, attrs.out_dtype) - if dm or dn: - original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape) - else: - out = out_ + out = relay.strided_slice(out_, begin=[0, 0, 0], end=[x.value for x in output_tensor.shape]) if dm or dn else out_ return out From 26eb1767a7182b95ab61cbff38ab78da8d3ec80e Mon Sep 17 00:00:00 2001 From: wyc-ruiker Date: Wed, 7 Jul 2021 20:23:05 +0800 Subject: [PATCH 10/11] fix for reviewer --- .../tvm/topi/cuda/batch_matmul_tensorcore.py | 4 ++- python/tvm/topi/cuda/dense_tensorcore.py | 4 ++- python/tvm/topi/cuda/tensorcore_alter_op.py | 26 +++++++++---------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py index 8b911d43c0e9..91ced70c5a96 100644 --- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py +++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py @@ -107,9 +107,11 @@ def _schedule(cfg, s, C): wmma_n = 32 elif wmma_m == 32: wmma_n = 8 - else: + elif data_dtype in ["int4", "uint4"]: wmma_m = wmma_n = 8 wmma_k = 32 + else: + raise ValueError('data dtype %s is not yet supported' % data_dtype) warp_size = 32 block_row_warps = cfg["block_row_warps"].val diff --git a/python/tvm/topi/cuda/dense_tensorcore.py b/python/tvm/topi/cuda/dense_tensorcore.py index 9b2b3d85b77f..cc2a40f8d742 100644 --- a/python/tvm/topi/cuda/dense_tensorcore.py +++ b/python/tvm/topi/cuda/dense_tensorcore.py @@ -147,9 +147,11 @@ def _schedule_dense_tensorcore(cfg, s, C): wmma_n = 32 elif wmma_m == 32: wmma_n = 8 - else: + elif data_dtype in ["int4", "uint4"]: wmma_m = wmma_n = 8 wmma_k = 32 + else: + raise ValueError('data dtype %s is not yet supported' % data_dtype) warp_size = 32 block_row_warps = cfg["block_row_warps"].val diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index 4392485b79b9..fffb0d6d48fc 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -90,7 +90,11 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) if dm or dk else x y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) if dn or dk else y out_ = relay.nn.batch_matmul(x_, y_, attrs.out_dtype) - out = relay.strided_slice(out_, begin=[0, 0, 0], end=[x.value for x in output_tensor.shape]) if dm or dn else out_ + out = ( + relay.strided_slice(out_, begin=[0, 0, 0], end=[x.value for x in output_tensor.shape]) + if dm or dn + else out_ + ) return out @@ -161,20 +165,14 @@ def _dense_legalize(attrs, inputs, arg_types): logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio) - if dm or dk: - x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) - else: - x_ = x - if dn or dk: - y_ = 
relay.nn.pad(y, pad_width=((0, dn), (0, dk))) - else: - y_ = y + x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) if dm or dk else x + y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) if dn or dk else y out_ = relay.nn.dense(x_, y_, **new_attrs) - if dm or dn: - original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape) - else: - out = out_ + out = ( + relay.strided_slice(out_, begin=[0, 0], end=[x.value for x in output_tensor.shape]) + if dm or dn + else out_ + ) return out From cbd0044e3440b2d3f9f59f8c52e57a1f98d8bf6f Mon Sep 17 00:00:00 2001 From: wyc-ruiker Date: Wed, 7 Jul 2021 20:27:49 +0800 Subject: [PATCH 11/11] fix lint --- python/tvm/topi/cuda/batch_matmul_tensorcore.py | 2 +- python/tvm/topi/cuda/dense_tensorcore.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py index 91ced70c5a96..a56d3c36ba33 100644 --- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py +++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py @@ -111,7 +111,7 @@ def _schedule(cfg, s, C): wmma_m = wmma_n = 8 wmma_k = 32 else: - raise ValueError('data dtype %s is not yet supported' % data_dtype) + raise ValueError("data dtype %s is not yet supported" % data_dtype) warp_size = 32 block_row_warps = cfg["block_row_warps"].val diff --git a/python/tvm/topi/cuda/dense_tensorcore.py b/python/tvm/topi/cuda/dense_tensorcore.py index cc2a40f8d742..9bac34cbeaf7 100644 --- a/python/tvm/topi/cuda/dense_tensorcore.py +++ b/python/tvm/topi/cuda/dense_tensorcore.py @@ -151,7 +151,7 @@ def _schedule_dense_tensorcore(cfg, s, C): wmma_m = wmma_n = 8 wmma_k = 32 else: - raise ValueError('data dtype %s is not yet supported' % data_dtype) + raise ValueError("data dtype %s is not yet supported" % data_dtype) warp_size = 32 block_row_warps = cfg["block_row_warps"].val
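# Condensed for reference: a standalone sketch (not library code; wmma_shape is a
# hypothetical helper name) of the fragment-shape rule the final schedules above
# settle on, mapping the input dtype and the tuned wmma_m knob to (m, n, k).

def wmma_shape(data_dtype, wmma_m=16):
    """Return (wmma_m, wmma_n, wmma_k) for a given input dtype."""
    if data_dtype in ["float16", "int8", "uint8"]:
        wmma_k = 16
        wmma_n = {8: 32, 16: 16, 32: 8}[wmma_m]  # (8,16,32), (16,16,16), (32,16,8)
    elif data_dtype in ["int4", "uint4"]:
        wmma_m = wmma_n = 8  # int4 tensorcore only supports the (8, 32, 8) shape
        wmma_k = 32
    else:
        raise ValueError("data dtype %s is not yet supported" % data_dtype)
    return wmma_m, wmma_n, wmma_k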