From 2db385de1b3b209318841fc13cda02de89279c54 Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Sat, 2 May 2020 01:41:47 +0000 Subject: [PATCH 01/21] int4 tensorcore --- src/target/source/codegen_c.cc | 8 +- .../topi/cuda/conv2d_nhwc_tensorcore.py | 174 +++++++++----- topi/python/topi/cuda/tensor_intrin.py | 23 +- .../test_topi_conv2d_nhwc_tensorcore.py | 215 ++++++++++++++++-- 4 files changed, 329 insertions(+), 91 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index adb84e498e5d..08065af1b109 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -595,9 +595,13 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) os << "(("; this->PrintType(l->dtype.element_of(), os); os << " *)" << this->GetVarID(l->buffer_var.get()) - << " + "; + << " + " << "("; this->PrintExpr(l->index, os); - os << ')'; + if (l->dtype.bits() == 4 || + (l->dtype.bits() == 1 && l->dtype.is_int())) { + os << " / " << (32 / l->dtype.bits()); + } + os << "))"; } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { CHECK_EQ(op->args.size(), 3U); os << GetStructRef( diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py index 790db0fe89a0..5b1776941bd1 100644 --- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py +++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py @@ -30,7 +30,7 @@ from .tensor_intrin import intrin_wmma_gemm -def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype): +def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype, out_dtype): """Compute declaration for tensorcore""" assert isinstance(stride, int) or len(stride) == 2 assert isinstance(dilation, int) or len(dilation) == 2 @@ -46,12 +46,20 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = get_const_tuple(Input.shape) - kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape) - assert (batch % 16 == 0 and in_channel % 16 == 0 and num_filter % 16 == 0) or \ + if in_dtype == 'int4': + kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape) + else: + kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape) + + if in_dtype == 'int4': + assert (batch % 8 == 0 and in_channel % 32 == 0 and num_filter % 8 == 0) + else: + assert (batch % 16 == 0 and in_channel % 16 == 0 and num_filter % 16 == 0) or \ (batch % 8 == 0 and in_channel % 16 == 0 and num_filter % 32 == 0) or \ (batch % 32 == 0 and in_channel % 16 == 0 and num_filter % 8 == 0), \ "The shape of (batch, in_channel, num_filter) "\ - "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now" + "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for fp16 and int8, "\ + "and (8, 32, 8) for int4" # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 @@ -70,16 +78,26 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp # convert data type of input feature maps and weights TransPaddedInput = te.compute( PaddedInput.shape, - lambda n, h, w, c: PaddedInput[n, h, w, c].astype('float16')) + lambda n, h, w, c: PaddedInput[n, h, w, c].astype(in_dtype)) TransFilter = te.compute( - Filter.shape, lambda h, w, i, o: Filter[h, w, i, o].astype('float16')) - Output = te.compute( - (batch, out_height, out_width, out_channel), - lambda nn, yy, xx, ff: te.sum( - TransPaddedInput[nn, yy * stride_h + ry * dilation_h, - xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * - TransFilter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]), - name="Conv2dOutput", tag="conv2d_nhwc_tensorcore") + Filter.shape, lambda h, w, i, o: Filter[h, w, i, o].astype(in_dtype)) + if in_dtype == 'int4': + Output = te.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: te.sum( + TransPaddedInput[nn, yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * + TransFilter[ry, rx, ff, rc].astype(out_dtype), axis=[ry, rx, rc]), + name="Conv2dOutput", tag="conv2d_nhwc_tensorcore") + else: + Output = te.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: te.sum( + TransPaddedInput[nn, yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * + TransFilter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]), + name="Conv2dOutput", tag="conv2d_nhwc_tensorcore") + return Output @@ -90,9 +108,12 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv): trans_paddata, kernel = s[Conv].op.input_tensors in_dtype = trans_paddata.dtype batch, _, _, _ = get_const_tuple(Conv.shape) - _, _, _, out_channels = get_const_tuple(kernel.shape) + + if in_dtype == 'int4': + _, _, out_channels, _ = get_const_tuple(kernel.shape) + else: + _, _, _, out_channels = get_const_tuple(kernel.shape) paddata = s[trans_paddata].op.input_tensors - # inline the pad and dtype transform s[trans_paddata].compute_inline() s[kernel].compute_inline() @@ -120,15 +141,14 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv): cfg.define_knob("warp_row_tiles", [1, 2, 4]) cfg.define_knob("warp_col_tiles", [1, 2, 4]) cfg.define_knob("chunk", [1, 2, 4, 8]) - cfg.define_knob("offset", [0, 8]) + if in_dtype == 'int8': + cfg.define_knob("offset", [0, 16]) + elif in_dtype == 'int4': + cfg.define_knob("offset", [0]) + else: + cfg.define_knob("offset", [0, 8]) cfg.define_knob("vector_width", [1, 2, 4, 8]) - - if (batch % 16 == 0 and out_channels % 16 == 0): - cfg.define_knob("wmma_m", [16, 8, 32]) - elif (batch % 8 == 0 and out_channels % 32 == 0): - cfg.define_knob("wmma_m", [8, 16, 32]) - elif (batch % 32 == 0 and out_channels % 8 == 0): - cfg.define_knob("wmma_m", [32, 16, 8]) + # cfg.define_knob("vector_width", [1]) # fallback support target = tvm.target.Target.current() @@ -143,16 +163,34 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv): warp_col_tiles = cfg["warp_col_tiles"].val chunk = cfg["chunk"].val offset = cfg["offset"].val - wmma_m = cfg["wmma_m"].val vector_width = cfg["vector_width"].val - - wmma_k = 16 - if wmma_m == 16: - wmma_n = 16 - elif wmma_m == 8: - wmma_n = 32 - elif wmma_m == 32: - wmma_n = 8 + block_row_warps = 1 + block_col_warps = 4 + warp_row_tiles = 2 + warp_col_tiles = 2 + chunk = 1 + offset = 0 + vector_width = 1 + + if in_dtype == 'int4': + wmma_m = wmma_n = 8 + wmma_k = 32 + else: + if (batch % 16 == 0 and out_channels % 16 == 0): + cfg.define_knob("wmma_m", [16, 8, 32]) + elif (batch % 8 == 0 and out_channels % 32 == 0): + cfg.define_knob("wmma_m", [8, 16, 32]) + elif (batch % 32 == 0 and out_channels % 8 == 0): + cfg.define_knob("wmma_m", [32, 16, 8]) + wmma_m = cfg["wmma_m"].val + wmma_m = 16 + wmma_k = 16 + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 warp_size = 32 @@ -168,17 +206,20 @@ def get_strides(extents): return [np.prod(extents[i:]).tolist() for i in range(len(extents))] AS_align = chunk * wmma_k + offset - WS_align = warp_col_tiles * block_col_warps * wmma_n + offset + if in_dtype == 'int4': + WS_align = chunk * warp_col_tiles * block_col_warps * wmma_k + offset + WL_strides = get_strides([wmma_k * warp_col_tiles, 1]) + else: + WS_align = warp_col_tiles * block_col_warps * wmma_n + offset + WL_strides = get_strides([wmma_n * warp_col_tiles, 1]) block_factor_n = wmma_m * warp_row_tiles * block_row_warps block_factor_o = wmma_n * warp_col_tiles * block_col_warps CS_align = block_factor_o + offset AS_strides = get_strides([1, 1, AS_align, 1]) AL_strides = get_strides([1, 1, wmma_k, 1]) WS_strides = get_strides([WS_align, 1]) - WL_strides = get_strides([wmma_n * warp_col_tiles, 1]) CL_strides = get_strides([1, 1, wmma_n * warp_col_tiles, 1]) CS_strides = get_strides([1, 1, CS_align, 1]) - # Schedule for output nc, hc, wc, oc = output.op.axis block_k = s[output].fuse(hc, wc) @@ -222,8 +263,8 @@ def get_strides(extents): ko, ki = s[ConvF].split(ic, factor=chunk) s[ConvF].reorder(kh, kw, ko, ki, n, o, nnf, oof, ii) - s[AF].compute_at(s[ConvF], ki) - s[WF].compute_at(s[ConvF], ki) + s[AF].compute_at(s[ConvF], n) + s[WF].compute_at(s[ConvF], n) # Schedule wmma load n, h, w, i = AF.op.axis @@ -231,11 +272,20 @@ def get_strides(extents): i, ii = s[AF].split(i, factor=wmma_k) s[AF].reorder(n, i, nn, ii) - kh, kw, i, o = WF.op.axis - i, ii = s[WF].split(i, factor=wmma_k) - o, oo = s[WF].split(o, factor=wmma_n) - s[WF].reorder(o, i, oo) - s[WF].reorder(i, o, ii, oo) + # kh, kw, i, o = WF.op.axis + if in_dtype == 'int4': + kh, kw, o, i = WF.op.axis + # print('kh, kw, o, i', kh, kw, o, i) + i, ii = s[WF].split(i, factor=wmma_k) + o, oo = s[WF].split(o, factor=wmma_n) + s[WF].reorder(o, i, oo) + s[WF].reorder(o, i, oo, ii) + else: + kh, kw, i, o = WF.op.axis + i, ii = s[WF].split(i, factor=wmma_k) + o, oo = s[WF].split(o, factor=wmma_n) + s[WF].reorder(o, i, oo) + s[WF].reorder(i, o, ii, oo) s[WS].compute_at(s[ConvF], ko) s[AS].compute_at(s[ConvF], ko) @@ -272,37 +322,54 @@ def get_strides(extents): # tensorize the wmma process AS_shape = (wmma_m, 1, 1, wmma_k) AL_shape = (wmma_m, 1, 1, wmma_k) - WS_shape = (wmma_k, wmma_n) - WL_shape = (wmma_k, wmma_n) + if in_dtype == 'int4': + WS_shape = (wmma_n, wmma_k) + WL_shape = (wmma_n, wmma_k) + else: + WS_shape = (wmma_k, wmma_n) + WL_shape = (wmma_k, wmma_n) CL_shape = (wmma_m, 1, 1, wmma_n) CS_shape = (wmma_m, 1, 1, wmma_n) AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype) WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype) k_gemm = te.reduce_axis((0, wmma_k), name="k") - CL_compute = te.compute(CL_shape, lambda ii, t0, t1, jj: - te.sum(AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * \ - WL_gemm[k_gemm, jj].astype(out_dtype), axis=k_gemm), - name='C') + if in_dtype == 'int4': + CL_compute = te.compute(CL_shape, lambda ii, t0, t1, jj: + te.sum(AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * \ + WL_gemm[jj, k_gemm].astype(out_dtype), axis=k_gemm), + name='C') + else: + CL_compute = te.compute(CL_shape, lambda ii, t0, t1, jj: + te.sum(AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * \ + WL_gemm[k_gemm, jj].astype(out_dtype), axis=k_gemm), + name='C') s[AF].tensorize(nn, intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape, "row_major", AS_shape, AL_shape, in_dtype)) - s[WF].tensorize(ii, intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, - "row_major", WS_shape, WL_shape, in_dtype)) + if in_dtype == 'int4': + s[WF].tensorize(oo, intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, + "col_major", WS_shape, WL_shape, in_dtype)) + else: + s[WF].tensorize(ii, intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, + "row_major", WS_shape, WL_shape, in_dtype)) s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, shape, out_dtype, CL_shape, CS_shape)) s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, WL_strides, CL_strides, shape)) N, OH, OW, CO = get_const_tuple(output.shape) - KH, KW, CI, _ = get_const_tuple(kernel.shape) + if in_dtype == 'int4': + KH, KW, _, CI = get_const_tuple(kernel.shape) + else: + KH, KW, CI, _ = get_const_tuple(kernel.shape) cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW) @autotvm.register_topi_compute("conv2d_nhwc_tensorcore.cuda") -def conv2d_nhwc_tensorcore(cfg, data, kernel, strides, padding, dilation, out_dtype): +def conv2d_nhwc_tensorcore(cfg, data, kernel, strides, padding, dilation, in_dtype, out_dtype): """Compute conv2d with tensorcore for NCHW layout""" - return nhwc_tensorcore_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype) + return nhwc_tensorcore_cuda(cfg, data, kernel, strides, padding, dilation, in_dtype, out_dtype) @autotvm.register_topi_schedule("conv2d_nhwc_tensorcore.cuda") @@ -316,3 +383,4 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + diff --git a/topi/python/topi/cuda/tensor_intrin.py b/topi/python/topi/cuda/tensor_intrin.py index f8fce342e212..fcb1458984e9 100644 --- a/topi/python/topi/cuda/tensor_intrin.py +++ b/topi/python/topi/cuda/tensor_intrin.py @@ -85,11 +85,11 @@ def intrin_wmma_load_matrix_A(strides_dst, strides_from, shape, layout, A_shape, A = te.placeholder(A_shape, name='A', dtype=in_dtype) BA = tvm.tir.decl_buffer(A.shape, A.dtype, - scope='shared', strides=strides_from, + scope='shared', strides=[te.var("s1"), te.var("s2"), te.var("s3"), te.var("s4")], data_alignment=32, offset_factor=8) C = te.compute(C_shape, lambda *i: A(*i), name='C') BC = tvm.tir.decl_buffer(C.shape, C.dtype, - scope="wmma.matrix_a", strides=strides_dst, + scope="wmma.matrix_a", strides=[te.var("s1"), te.var("s2"), te.var("s3"), te.var("s4")], data_alignment=32, offset_factor=8) def intrin_func(ins, outs): @@ -110,14 +110,13 @@ def intrin_func(ins, outs): def intrin_wmma_load_matrix_W(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype): """Intrin function for loading data from shared memory to wmma.matrix_b""" wmma_m, wmma_n, wmma_k = shape - A = te.placeholder(A_shape, name='A', dtype=in_dtype) BA = tvm.tir.decl_buffer(A.shape, A.dtype, - scope='shared', strides=strides_from, + scope='shared', strides=[te.var("s1"), te.var("s2")], data_alignment=32, offset_factor=8) C = te.compute(C_shape, lambda *i: A(*i), name='C') BC = tvm.tir.decl_buffer(C.shape, C.dtype, - scope="wmma.matrix_b", strides=strides_dst, + scope="wmma.matrix_b", strides=[te.var("s3"), te.var("s4")], data_alignment=32, offset_factor=8) def intrin_func(ins, outs): @@ -183,25 +182,25 @@ def intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, strides_A, BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, - offset_factor=8, strides=strides_A) + offset_factor=8, strides=[te.var("s1"), te.var("s2"), te.var("s3"), te.var("s4")]) BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, - offset_factor=8, strides=strides_W) + offset_factor=8, strides=[te.var("s1"), te.var("s2")]) BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, - offset_factor=8, strides=strides_Conv) + offset_factor=8, strides=[te.var("s1"), te.var("s2"), te.var("s3"), te.var("s4")]) def intrin_func(ins, outs): BA, BB = ins BC, = outs - def warp_idnex(offset, row, col): + def warp_index(offset, row, col): row = row * col return offset // row + offset % row // col - warp_index_A = warp_idnex(BA.elem_offset, wmma_m, wmma_k) - warp_index_B = warp_idnex(BB.elem_offset, wmma_k, wmma_n) - warp_index_C = warp_idnex(BC.elem_offset, wmma_m, wmma_n) + warp_index_A = warp_index(BA.elem_offset, wmma_m, wmma_k) + warp_index_B = warp_index(BB.elem_offset, wmma_k, wmma_n) + warp_index_C = warp_index(BC.elem_offset, wmma_m, wmma_n) def init(): ib = tvm.tir.ir_builder.create() diff --git a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py index cc327849caea..84fd4ea48927 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py @@ -19,14 +19,37 @@ import numpy as np import tvm +import os import topi import topi.testing -from tvm import te +from tvm import te, autotvm from tvm.contrib.pickle_memoize import memoize from tvm.contrib import nvcc from topi.nn.util import get_pad_tuple from topi.util import get_const_tuple +TASK="conv_int4" + +USE_MANUAL_CODE = False + +# @tvm.register_func +# def tvm_callback_cuda_compile(code): +# ptx = nvcc.compile_cuda(code, target="ptx") +# return ptx + +def write_code(code, fname): + with open(fname, "w") as f: + f.write(code) + +@tvm.register_func +def tvm_callback_cuda_postproc(code): + if not os.path.exists("perf"): + os.mkdir("perf") + write_code(code, "perf/%s_generated.cu" % TASK) + if USE_MANUAL_CODE: + code = open("perf/%s_manual.cu" % TASK).read() + return code + _conv2d_nhwc_tensorcore_implement = { "cuda": (topi.cuda.conv2d_nhwc_tensorcore, topi.cuda.schedule_conv2d_nhwc_tensorcore) @@ -41,32 +64,77 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % ( batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + # choose dtype from int4, int8 and float16 + dtype = 'int4' + out_dtype = 'int32' + in_height = in_width = in_size - A = te.placeholder((batch, in_height, in_width, in_channel), name='A') - W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W') - bias = te.placeholder((1, 1, 1, num_filter), name='bias') + A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype=dtype) + if dtype == 'int4': + W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype) + else: + W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype=dtype) + + bias = te.placeholder((1, 1, 1, num_filter), name='bias', dtype=out_dtype) a_shape = get_const_tuple(A.shape) w_shape = get_const_tuple(W.shape) bias_shape = get_const_tuple(bias.shape) - dtype = A.dtype + # dtype = A.dtype @memoize("topi.tests.test_topi_conv2d_nhwc.verify_conv2d_nhwc") def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = np.random.uniform(size=bias_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + if dtype == 'float16': + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(out_dtype) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + elif dtype == 'int4': + a_np = np.random.randint(low=1, high=7, size=a_shape) + b_np = np.random.randint(low=1, high=7, size=bias_shape) + w_np = np.random.randint(low=1, high=7, size=w_shape) + dw_np = topi.testing.dilate_python(w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation)) + elif dtype == 'int8': + a_np = np.random.randint(low=1, high=7, size=a_shape).astype(dtype) + w_np = np.random.randint(low=1, high=7, size=w_shape).astype(dtype) + b_np = np.random.randint(low=1, high=7, size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) if add_bias: - b_np = np.random.uniform(size=bias_shape).astype(dtype) + # b_np = np.random.uniform(size=bias_shape).astype(out_dtype) c_np += b_np if add_relu: c_np = np.maximum(c_np, 0) return a_np, w_np, b_np, c_np + + def convert_int32_into_int4(a_int32): + """ convert int32 values into int4 + Parameters + ---------- + a_int32 : int + + Return + ------ + a_int4 : int + """ + I, J, K, L = a_int32.shape + a_int4 = np.zeros(shape=(I, J, K, L // 8), dtype=np.int32) + # for g in range(G): + for i in range(I): + for j in range(J): + for k in range(K): + for l in range(L // 8): + for m in range(min(8, L-l*8)): + a_int4[i, j, k, l] = a_int4[i, j, k, l] | ((a_int32[i, j, k, l * 8 + m] & 0xf) << ((7 - m) * 4)) + return a_int4 a_np, w_np, b_np, c_np = get_ref_data() + if dtype == 'int4': + a_np = convert_int32_into_int4(a_np) + b_np = convert_int32_into_int4(b_np) + w_np = convert_int32_into_int4(w_np) def check_device(device): ctx = tvm.context(device, 0) @@ -79,7 +147,10 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): fcompute, fschedule = topi.testing.dispatch(device, _conv2d_nhwc_tensorcore_implement) - C = fcompute(A, W, stride, padding, dilation, 'float32') + if dtype == 'float16': + C = fcompute(A, W, stride, padding, dilation, dtype, 'float') + else: + C = fcompute(A, W, stride, padding, dilation, dtype, 'int32') if add_bias: C = topi.add(C, bias) if add_relu: @@ -95,31 +166,127 @@ def check_device(device): batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, b, c) else: + # print(tvm.lower(s, [A, W, C], simple_mode=True)) func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, c) + # warm up + evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) + evaluator(a, w, c) + print('Time cost of this operator: %f ms' % (evaluator(a, w, c).mean * 1000)) rtol = 1e-3 tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol) + # # #Tuning the performance + # import logging, sys + # logging.getLogger('autotvm').setLevel(logging.DEBUG) + # logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) + + # log_filename = "conv2d_int4_nhwc_tensorcore.log" + # tmp_log_file = log_filename + '.temp' + # num_trial = 1000 + # task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, + # padding, dilation) + # task = autotvm.create('conv2d_nhwc_tensorcore.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) + # print(task.config_space) + + # measure_option = autotvm.measure_option( + # builder='local', + # runner=autotvm.LocalRunner(number=5)) + + # tuner = autotvm.tuner.XGBTuner(task) + # num_trial = min(num_trial, len(task.config_space)) + # with tvm.target.build_config(): + # tuner.tune(n_trial=num_trial, + # measure_option=measure_option, + # callbacks=[autotvm.callback.progress_bar(num_trial, prefix=task_name), + # autotvm.callback.log_to_file(tmp_log_file)]) + + # dispatch_context = autotvm.apply_history_best(tmp_log_file) + # best_config = dispatch_context.query(task.target, task.workload) + # print("\nBest config:") + # print(best_config) + + # #pick the best record to a cache file + # autotvm.record.pick_best(tmp_log_file, log_filename) + # os.remove(tmp_log_file) + + # with autotvm.apply_graph_best(log_filename): + # with tvm.target.create(device): + # func = tvm.build(s, [A, W, C], device, name="conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, + # padding, dilation)) + # evaluator = func.time_evaluator(func.entry_name, ctx, number=100, repeat=10) + # print('Time cost of this operator after tuning: %f ms' % (evaluator(a, w, c).mean * 1000)) + check_device(devices) def test_conv2d_nhwc_tensorcore(): """Test the conv2d with tensorcore for nhwc layout""" - verify_conv2d_nhwc(16, 16, 14, 16, 3, 1, 1) - verify_conv2d_nhwc(16, 128, 7, 128, 7, 1, 3) - verify_conv2d_nhwc(16, 160, 7, 160, 7, 1, 3) - - verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_bias=True) - verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True) - verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True, add_bias=True) - - verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, (3, 3, 2, 2)) - verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, "SAME") - verify_conv2d_nhwc(16, 48, 35, 48, 5, 1, "VALID") - verify_conv2d_nhwc(16, 48, 56, 48, 3, 1, (1, 1, 1, 1)) - verify_conv2d_nhwc(16, 64, 28, 64, 3, 1, (1, 1, 1, 1)) + # verify_conv2d_nhwc(64, 64, 56, 64, 3, 1, 1) + # verify_conv2d_nhwc(64, 64, 56, 64, 1, 1, 0) + # verify_conv2d_nhwc(64, 64, 56, 128, 3, 2, 1) + # verify_conv2d_nhwc(64, 64, 56, 64, 1, 2, 0) + # verify_conv2d_nhwc(64, 128, 28, 128, 3, 1, 1) + # verify_conv2d_nhwc(64, 128, 28, 256, 3, 2, 1) + # verify_conv2d_nhwc(64, 128, 28, 256, 1, 2, 0) + # verify_conv2d_nhwc(64, 256, 14, 256, 3, 1, 1) + # verify_conv2d_nhwc(64, 256, 14, 512, 3, 2, 1) + # verify_conv2d_nhwc(64, 256, 14, 512, 1, 2, 0) + # verify_conv2d_nhwc(64, 512, 7, 512, 3, 1, 1) + + # verify_conv2d_nhwc(32, 64, 56, 64, 3, 1, 1) + # verify_conv2d_nhwc(32, 64, 56, 64, 1, 1, 0) + # verify_conv2d_nhwc(32, 64, 56, 128, 3, 2, 1) + # verify_conv2d_nhwc(32, 64, 56, 64, 1, 2, 0) + # verify_conv2d_nhwc(32, 128, 28, 128, 3, 1, 1) + # verify_conv2d_nhwc(32, 128, 28, 256, 3, 2, 1) + # verify_conv2d_nhwc(32, 128, 28, 256, 1, 2, 0) + # verify_conv2d_nhwc(32, 256, 14, 256, 3, 1, 1) + # verify_conv2d_nhwc(32, 256, 14, 512, 3, 2, 1) + # verify_conv2d_nhwc(32, 256, 14, 512, 1, 2, 0) + # verify_conv2d_nhwc(32, 512, 7, 512, 3, 1, 1) + + # verify_conv2d_nhwc(16, 64, 56, 64, 3, 1, 1) + # verify_conv2d_nhwc(16, 64, 56, 64, 1, 1, 0) + # verify_conv2d_nhwc(16, 64, 56, 128, 3, 2, 1) + # verify_conv2d_nhwc(16, 64, 56, 64, 1, 2, 0) + # verify_conv2d_nhwc(16, 128, 28, 128, 3, 1, 1) + # verify_conv2d_nhwc(16, 128, 28, 256, 3, 2, 1) + # verify_conv2d_nhwc(16, 128, 28, 256, 1, 2, 0) + # verify_conv2d_nhwc(16, 256, 14, 256, 3, 1, 1) + # verify_conv2d_nhwc(16, 256, 14, 512, 3, 2, 1) + # verify_conv2d_nhwc(16, 256, 14, 512, 1, 2, 0) + # verify_conv2d_nhwc(16, 512, 7, 512, 3, 1, 1) + + verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) + # verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) + verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) + # verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) + # verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1) + # verify_conv2d_nhwc(8, 128, 28, 256, 3, 2, 1) + # verify_conv2d_nhwc(8, 128, 28, 256, 1, 2, 0) + # verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1) + # verify_conv2d_nhwc(8, 256, 14, 512, 3, 2, 1) + # verify_conv2d_nhwc(8, 256, 14, 512, 1, 2, 0) + # verify_conv2d_nhwc(8, 512, 7, 512, 3, 1, 1) + + + # verify_conv2d_nhwc(32, 1024, 14, 256, 1, 1, 1) + + # verify_conv2d_nhwc(16, 128, 7, 128, 7, 1, 3) + # verify_conv2d_nhwc(16, 160, 7, 160, 7, 1, 3) + + # verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_bias=True) + # verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True) + # verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True, add_bias=True) + + # verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, (3, 3, 2, 2)) + # verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, "SAME") + # verify_conv2d_nhwc(16, 48, 35, 48, 5, 1, "VALID") + # verify_conv2d_nhwc(16, 48, 56, 48, 3, 1, (1, 1, 1, 1)) + # verify_conv2d_nhwc(16, 64, 28, 64, 3, 1, (1, 1, 1, 1)) if __name__ == "__main__": From 23ac8f9ee7b693833df6beaeecdee448a70da0de Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Tue, 5 May 2020 02:25:55 +0000 Subject: [PATCH 02/21] a draft for new int4 schedule --- .../topi/cuda/conv2d_nhwc_tensorcore_int4.py | 395 ++++++++++++++++++ .../test_topi_conv2d_nhwc_tensorcore_int4.py | 335 +++++++++++++++ 2 files changed, 730 insertions(+) create mode 100644 topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py create mode 100644 topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py new file mode 100644 index 000000000000..186ed7e9fc6b --- /dev/null +++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py @@ -0,0 +1,395 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-function-args +# pylint: disable=too-many-statements, unused-argument, too-many-arguments +"""Tensorcore template for cuda backend""" +import numpy as np +import tvm +from tvm import te +from tvm import autotvm +from ..util import get_const_tuple, traverse_inline, simplify +from ..nn.pad import pad +from ..nn.util import get_pad_tuple +# from .tensor_intrin import intrin_wmma_load_matrix_A +# from .tensor_intrin import intrin_wmma_load_matrix_W +# from .tensor_intrin import intrin_wmma_store_matrix +# from .tensor_intrin import intrin_wmma_gemm + +def intrin_wmma_load_matrix(scope): + n = m = 8 + l = 32 + if scope == 'wmma.matrix_a': + A = tvm.te.placeholder((n, l), name='A', dtype='int4') + C = tvm.te.compute((n, l), lambda i, j: A[i, j], name='C') + else: + A = tvm.te.placeholder((m, l), name='A', dtype='int4') + C = tvm.te.compute((m, l), lambda i, j: A[i, j], name='C') + # A = te.placeholder((n, m), name='A', dtype='int4') + # C = te.compute((m, n), lambda i, j: A[i, j], name='C') + BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) + BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + + BA = ins[0] + BC = outs[0] + if scope == "wmma.matrix_a": + ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', + BC.data, n, m, l, BC.elem_offset // 256, + BA.access_ptr('r'), l, 'row_major')) + elif scope == "wmma.matrix_b": + ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', + BC.data, n, m, l, BC.elem_offset // 256, + BA.access_ptr('r'), l, 'col_major')) + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + + +def intrin_wmma_gemm(): + n = m = 8 + l = 32 + A = te.placeholder((n, l), name='A', dtype='int4') + B = te.placeholder((n, l), name='B', dtype='int4') + k = te.reduce_axis((0, l), name="k") + C = te.compute((n, n), + lambda ii, jj: + te.sum(A[ii, k].astype('int32') * B[jj, k].astype('int32'), axis=k), + name='C') + BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256) + BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256) + BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=64) + + def intrin_func(ins, outs): + BA, BB = ins + BC, = outs + + def init(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, m, l, BC.elem_offset // n * n, 0.0)) + return ib.get() + + def update(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync', + BC.data, BC.elem_offset // 64, + BA.data, BA.elem_offset // 256, + BB.data, BB.elem_offset // 256, + BC.data, BC.elem_offset // 64)) + return ib.get() + + return update(), init(), update() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + + +def intrin_wmma_store_matrix(): + n = m = 8 + l = 32 + A = te.placeholder((n, m), name='A', dtype='int32') + BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=64) + C = te.compute((n, m), lambda i, j: A[i, j], name='C') + BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=64) + + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + BA = ins[0] + BC = outs[0] + ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync', + BA.data, n, m, l, BA.elem_offset // 64, + BC.access_ptr('w'), n, 'row_major')) + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + +def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype, out_dtype): + """Compute declaration for tensorcore""" + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + wmma_n = wmma_m = 8 + wmma_k = 32 + + batch, in_height, in_width, in_channels, wmma_m, wmma_k = get_const_tuple(Input.shape) + if in_dtype == 'int4': + kernel_h, kernel_w, _, num_filter, wmma_n, wmma_k = get_const_tuple(Filter.shape) + else: + kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape) + + if in_dtype == 'int4': + pass + # assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0) + else: + assert (batch % 16 == 0 and in_channels % 16 == 0 and num_filter % 16 == 0) or \ + (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0) or \ + (batch % 32 == 0 and in_channels % 16 == 0 and num_filter % 8 == 0), \ + "The shape of (batch, in_channels, num_filter) "\ + "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for fp16 and int8, "\ + "and (8, 32, 8) for int4" + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + out_channels = num_filter + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + # PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") + # Input feature map: (N, H, W, IC, n, ic) + data_shape = (batch, + in_height, + in_width, + in_channels, + wmma_m, + wmma_k) + # Kernel: (H, W, IC, OC, ic, oc) + kernel_shape = (kernel_h, + kernel_w, + in_channels, + out_channels, + wmma_n, + wmma_k) + output_shape = (batch, + out_height, + out_width, + out_channels, + wmma_m, + wmma_n) + # rc = te.reduce_axis((0, in_channel), name='rc') + # ry = te.reduce_axis((0, kernel_h), name='ry') + # rx = te.reduce_axis((0, kernel_w), name='rx') + # Reduction axes + kh = te.reduce_axis((0, kernel_h), name='kh') + kw = te.reduce_axis((0, kernel_w), name='kw') + ic = te.reduce_axis((0, in_channels), name='ic') + ii = te.reduce_axis((0, wmma_k), name='ii') + # Algorithm + # A = te.placeholder(data_shape, name='A', dtype="int4") + # W = te.placeholder(kernel_shape, name='W', dtype="int4") + Apad = te.compute( + (batch, in_height + 2 * padding, in_width + 2 * padding, in_channels, wmma_m, + wmma_k), + lambda n, h, w, i, nn, ii: tvm.tir.if_then_else( + tvm.tir.all(h >= padding, h - padding < in_height, + w >= padding, w - padding < in_width), + Input[n, h - padding, w - padding, i, nn, ii], tvm.tir.const(0., "int4")), + name='Apad') + Conv = te.compute(output_shape, + lambda n, h, w, o, nn, oo: te.sum( + Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("int32") * + Filter[kh, kw, ic, o, oo, ii].astype("int32"), + axis=[ic, kh, kw, ii]), + name="Conv", tag="conv2d_nhwc_tensorcore_int4") + + return Conv + + +def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Conv): + """Schedule tensorcore template""" + ic, kh, kw, ii = s[Conv].op.reduce_axis + out_dtype = Conv.dtype + # trans_paddata, kernel = s[Conv].op.input_tensors + Apad, kernel = s[Conv].op.input_tensors + s[Apad].compute_inline() + in_dtype = Apad.dtype + batch, _, _, _, _, _ = get_const_tuple(Conv.shape) + if in_dtype == 'int4': + _, _, out_channels, _, _, _ = get_const_tuple(kernel.shape) + else: + _, _, _, out_channels, _, _ = get_const_tuple(kernel.shape) + # inline the pad and dtype transform + # s[kernel].compute_inline() + # s[paddata[0]].compute_inline() + + block_x = te.thread_axis('blockIdx.x') + block_y = te.thread_axis('blockIdx.y') + block_z = te.thread_axis('blockIdx.z') + thread_x = te.thread_axis('threadIdx.x') + thread_y = te.thread_axis('threadIdx.y') + thread_z = te.thread_axis('threadIdx.z') + + # Designate the memory hierarchy + AS = s.cache_read(Apad, 'shared', [Conv]) + WS = s.cache_read(kernel, 'shared', [Conv]) + AF = s.cache_read(AS, 'wmma.matrix_a', [Conv]) + WF = s.cache_read(WS, 'wmma.matrix_b', [Conv]) + ConvF = s.cache_write(Conv, 'wmma.accumulator') + + # todo + # if Conv.op in s.outputs: + # output = Conv + # ConvS = s.cache_read(ConvF, 'shared', [Conv]) + # OL = ConvS + # else: + # output = s.outputs[0].output(0) + # s[Conv].set_scope('shared') + # OL = Conv + + # Schedule for autotvm + cfg.define_knob("block_row_warps", [1, 2, 4, 8]) + cfg.define_knob("block_col_warps", [1, 2, 4, 8]) + cfg.define_knob("warp_row_tiles", [1, 2, 4, 8]) + cfg.define_knob("warp_col_tiles", [1, 2, 4, 8]) + cfg.define_knob("chunk", [1, 2, 4, 8]) + # if in_dtype == 'int8': + # cfg.define_knob("offset", [0, 16]) + # elif in_dtype == 'int4': + # cfg.define_knob("offset", [0]) + # else: + # cfg.define_knob("offset", [0, 8]) + # cfg.define_knob("vector_width", [1, 2, 4, 8]) + cfg.define_knob("vector_width", [1, 8]) + + # fallback support + target = tvm.target.Target.current() + if cfg.is_fallback: + ref_log = autotvm.tophub.load_reference_log( + target.target_name, target.model, 'conv2d_nhwc_tensorcore_int4.cuda') + cfg.fallback_with_reference_log(ref_log) + + block_row_warps = cfg["block_row_warps"].val + block_col_warps = cfg["block_col_warps"].val + warp_row_tiles = cfg["warp_row_tiles"].val + warp_col_tiles = cfg["warp_col_tiles"].val + chunk = cfg["chunk"].val + # offset = cfg["offset"].val + vector_width = cfg["vector_width"].val + block_row_warps = 1 + block_col_warps = 8 + warp_row_tiles = 2 + warp_col_tiles = 1 + chunk = 4 + vector_width = 1 + + # offset = 0 + + if in_dtype == 'int4': + wmma_m = wmma_n = 8 + wmma_k = 32 + else: + if (batch % 16 == 0 and out_channels % 16 == 0): + cfg.define_knob("wmma_m", [16, 8, 32]) + elif (batch % 8 == 0 and out_channels % 32 == 0): + cfg.define_knob("wmma_m", [8, 16, 32]) + elif (batch % 32 == 0 and out_channels % 8 == 0): + cfg.define_knob("wmma_m", [32, 16, 8]) + wmma_m = cfg["wmma_m"].val + # wmma_m = 16 + wmma_k = 16 + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 + + warp_size = 32 + + nc, hc, wc, oc, nnc, ooc = Conv.op.axis + block_k = s[Conv].fuse(hc, wc) + s[Conv].bind(block_k, block_z) + nc, nci = s[Conv].split(nc, factor=warp_row_tiles) + block_i, nc = s[Conv].split(nc, factor=block_row_warps) + oc, oci = s[Conv].split(oc, factor=warp_col_tiles) + block_j, oc = s[Conv].split(oc, factor=block_col_warps) + s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc) + s[Conv].bind(block_i, block_x) + s[Conv].bind(block_j, block_y) + s[Conv].bind(nc, thread_y) + s[Conv].bind(oc, thread_z) + # Schedule local computation + s[ConvF].compute_at(s[Conv], oc) + n, h, w, o, nnf, oof = ConvF.op.axis + ko, ki = s[ConvF].split(ic, factor=chunk) + s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii) + + # Move intermediate computation into each output compute tile + s[AF].compute_at(s[ConvF], kw) + s[WF].compute_at(s[ConvF], kw) + + # vector_width=8 + # Schedule for A's share memory + s[AS].compute_at(s[ConvF], kh) + n, h, w, i, nn, ii = AS.op.axis + tx, xo = s[AS].split(n, nparts=block_row_warps) + ty, yo = s[AS].split(xo, nparts=block_col_warps) + t = s[AS].fuse(nn, ii) + to, ti = s[AS].split(t, factor=warp_size) + # ti, _t = s[AS].split(ti, factor=vector_width) + s[AS].bind(tx, thread_y) + s[AS].bind(ty, thread_z) + s[AS].bind(ti, thread_x) + # s[AS].vectorize(ti) + + # Schedule for W's share memory + s[WS].compute_at(s[ConvF], kh) + kh, kw, ic, o, ii, oo = WS.op.axis + tx, xo = s[WS].split(o, nparts=block_row_warps) + ty, yo = s[WS].split(xo, nparts=block_col_warps) + t = s[WS].fuse(ii, oo) + to, ti = s[WS].split(t, nparts=warp_size) + ti, _t = s[WS].split(ti, factor=vector_width) + s[WS].bind(tx, thread_y) + s[WS].bind(ty, thread_z) + s[WS].bind(to, thread_x) + s[WS].vectorize(ti) + + s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a')) + s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b')) + s[Conv].tensorize(nnc, intrin_wmma_store_matrix()) + s[ConvF].tensorize(nnf, intrin_wmma_gemm()) + + + N, OH, OW, CO, nn, mm = get_const_tuple(Conv.shape) + if in_dtype == 'int4': + KH, KW, _, CI, _, ci = get_const_tuple(kernel.shape) + else: + KH, KW, CI, _ = get_const_tuple(kernel.shape) + cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * ci * nn * mm) + + +@autotvm.register_topi_compute("conv2d_nhwc_tensorcore_int4.cuda") +def conv2d_nhwc_tensorcore_int4(cfg, data, kernel, strides, padding, dilation, in_dtype, out_dtype): + """Compute conv2d with tensorcore for NCHW layout""" + return nhwc_tensorcore_cuda(cfg, data, kernel, strides, padding, dilation, in_dtype, out_dtype) + + +@autotvm.register_topi_schedule("conv2d_nhwc_tensorcore_int4.cuda") +def schedule_conv2d_nhwc_tensorcore_int4(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + def _callback(op): + if 'conv2d_nhwc_tensorcore_int4' in op.tag: + schedule_nhwc_tensorcore_cuda_int4(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s + diff --git a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py new file mode 100644 index 000000000000..af135dd52a04 --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py @@ -0,0 +1,335 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-arguments +"""Example code to do convolution.""" + +import numpy as np +import tvm +import os +import topi +import topi.testing +from tvm import te, autotvm +from tvm.contrib.pickle_memoize import memoize +from tvm.contrib import nvcc +from topi.nn.util import get_pad_tuple +from topi.util import get_const_tuple + +TASK="conv_int4" + +USE_MANUAL_CODE = False + +# @tvm.register_func +# def tvm_callback_cuda_compile(code): +# ptx = nvcc.compile_cuda(code, target="ptx") +# return ptx + +def write_code(code, fname): + with open(fname, "w") as f: + f.write(code) + +@tvm.register_func +def tvm_callback_cuda_postproc(code): + if not os.path.exists("perf"): + os.mkdir("perf") + write_code(code, "perf/%s_generated.cu" % TASK) + if USE_MANUAL_CODE: + code = open("perf/%s_manual.cu" % TASK).read() + return code + + +_conv2d_nhwc_tensorcore_implement = { + "cuda": (topi.cuda.conv2d_nhwc_tensorcore_int4, topi.cuda.schedule_conv2d_nhwc_tensorcore_int4) +} + + +def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, + padding, dilation=1, add_bias=False, add_relu=False, devices='cuda'): + """Test the conv2d with tensorcore for nhwc layout""" + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + + # choose dtype from int4, int8 and float16 + dtype = 'int4' + out_dtype = 'int32' + wmma_n = wmma_m = 8 + wmma_k = 32 + in_height = in_width = in_size + + A = te.placeholder((batch // wmma_m, in_height, in_width, in_channel // wmma_k, wmma_m, wmma_k), name='A', dtype=dtype) + if dtype == 'int4' or dtype == 'int8': + W = te.placeholder((kernel, kernel, in_channel // wmma_k, num_filter // wmma_n , wmma_n, wmma_k), name='W', dtype=dtype) + else: + W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype=dtype) + + bias = te.placeholder((1, 1, 1, num_filter), name='bias', dtype=out_dtype) + + # a_shape = get_const_tuple(A.shape) + # w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + a_shape = (batch, in_height, in_width, in_channel) + # w_shape = (kernel, kernel, num_filter, in_channel) + w_shape = (kernel, kernel, in_channel, num_filter) + # dtype = A.dtype + + @memoize("topi.tests.test_topi_conv2d_nhwc.verify_conv2d_nhwc") + def get_ref_data(): + if dtype == 'float16': + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(out_dtype) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + elif dtype == 'int4': + a_np = np.random.randint(low=-7, high=7, size=a_shape).astype(np.int32) + b_np = np.random.randint(low=-7, high=7, size=bias_shape).astype(np.int32) + w_np = np.random.randint(low=-7, high=7, size=w_shape).astype(np.int32) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + elif dtype == 'int8': + a_np = np.random.randint(low=1, high=7, size=a_shape).astype(dtype) + w_np = np.random.randint(low=1, high=7, size=w_shape).astype(dtype) + b_np = np.random.randint(low=1, high=7, size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + + c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + if add_bias: + # b_np = np.random.uniform(size=bias_shape).astype(out_dtype) + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + return a_np, w_np, b_np, c_np + + def convert_int32_into_int4(a_int32): + """ convert int32 values into int4 + Parameters + ---------- + a_int32 : int + + Return + ------ + a_int4 : int + """ + M, N, I, J, K, L = a_int32.shape + a_int4 = np.zeros(shape=(M, N, I, J, K, L // 8), dtype=np.int32) + for m in range(M): + for n in range(N): + for i in range(I): + for j in range(J): + for k in range(K): + for l in range(L // 8): + for a in range(8): + a_int4[m,n,i,j,k,l] = a_int4[m,n,i,j,k,l] | ((a_int32[m,n,i,j,k,l * 8 + a] & 0xf) << ((7 - a) * 4)) + return a_int4 + + a_np, w_np, b_np, c_np = get_ref_data() + + if dtype == 'int4': + a_np_tvm = a_np.reshape((batch // wmma_m, + wmma_m, + in_height, + in_width, + in_channel // wmma_k, + wmma_k)).transpose((0,2,3,4,1,5)) + w_np_tvm = w_np.reshape((kernel, + kernel, + in_channel // wmma_k, + wmma_k, + num_filter // wmma_n, + wmma_n)).transpose((0,1,2,4,5,3)) + a_np = convert_int32_into_int4(a_np_tvm) + # b_np = convert_int32_into_int4(b_np) + w_np = convert_int32_into_int4(w_np_tvm) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + if not nvcc.have_tensorcore(ctx.compute_version): + print("skip because gpu does not support Tensor Cores") + return + print("Running on target: %s" % device) + with tvm.target.create(device): + fcompute, fschedule = topi.testing.dispatch(device, _conv2d_nhwc_tensorcore_implement) + if dtype == 'float16': + C = fcompute(A, W, stride, padding, dilation, dtype, 'float') + else: + C = fcompute(A, W, stride, padding, dilation, dtype, 'int32') + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = fschedule([C]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + if add_bias: + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func(a, w, b, c) + else: + # print(tvm.lower(s, [A, W, C], simple_mode=True)) + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func(a, w, c) + dev_module = func.imported_modules[0] + # print(dev_module.get_source()) + # warm up + evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) + func.time_evaluator(func.entry_name, ctx, number=50, repeat=20)(a, w, c) + # evaluator(a, w, c) + print('Time cost of this operator: %f ms' % (evaluator(a, w, c).mean * 1000)) + + rtol = 1e-3 + # print(c.asnumpy().shape, c_np.sum()) + tvm.testing.assert_allclose(c.asnumpy().transpose((0,4,1,2,3,5)).reshape(c_np.shape), c_np, rtol=rtol) + + # # #Tuning the performance + # import logging, sys + # logging.getLogger('autotvm').setLevel(logging.DEBUG) + # logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) + + # log_filename = "conv2d_int4_nhwc_tensorcore.log" + # tmp_log_file = log_filename + '.temp' + # num_trial = 2000 + # task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, + # padding, dilation) + # task = autotvm.create('conv2d_nhwc_tensorcore_int4.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) + # print(task.config_space) + + # measure_option = autotvm.measure_option( + # builder='local', + # runner=autotvm.LocalRunner(number=5)) + + # tuner = autotvm.tuner.XGBTuner(task) + # num_trial = min(num_trial, len(task.config_space)) + # with tvm.target.build_config(): + # tuner.tune(n_trial=num_trial, + # measure_option=measure_option, + # callbacks=[autotvm.callback.progress_bar(num_trial, prefix=task_name), + # autotvm.callback.log_to_file(tmp_log_file)]) + + # dispatch_context = autotvm.apply_history_best(tmp_log_file) + # best_config = dispatch_context.query(task.target, task.workload) + # print("\nBest config:") + # print(best_config) + + # #pick the best record to a cache file + # autotvm.record.pick_best(tmp_log_file, log_filename) + # os.remove(tmp_log_file) + + # with autotvm.apply_graph_best(log_filename): + # with tvm.target.create(device): + # func = tvm.build(s, [A, W, C], device, name="conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, + # padding, dilation)) + # evaluator = func.time_evaluator(func.entry_name, ctx, number=100, repeat=10) + # print('Time cost of this operator after tuning: %f ms' % (evaluator(a, w, c).mean * 1000)) + + check_device(devices) + + +def test_conv2d_nhwc_tensorcore(): + """Test the conv2d with tensorcore for nhwc layout""" + # verify_conv2d_nhwc(64, 64, 56, 64, 3, 1, 1) + # verify_conv2d_nhwc(64, 64, 56, 64, 1, 1, 0) + # verify_conv2d_nhwc(64, 64, 56, 128, 3, 2, 1) + # verify_conv2d_nhwc(64, 64, 56, 64, 1, 2, 0) + # verify_conv2d_nhwc(64, 128, 28, 128, 3, 1, 1) + # verify_conv2d_nhwc(64, 128, 28, 256, 3, 2, 1) + # verify_conv2d_nhwc(64, 128, 28, 256, 1, 2, 0) + # verify_conv2d_nhwc(64, 256, 14, 256, 3, 1, 1) + # verify_conv2d_nhwc(64, 256, 14, 512, 3, 2, 1) + # verify_conv2d_nhwc(64, 256, 14, 512, 1, 2, 0) + # verify_conv2d_nhwc(64, 512, 7, 512, 3, 1, 1) + + # verify_conv2d_nhwc(32, 64, 56, 64, 3, 1, 1) + # verify_conv2d_nhwc(32, 64, 56, 64, 1, 1, 0) + # verify_conv2d_nhwc(32, 64, 56, 128, 3, 2, 1) + # verify_conv2d_nhwc(32, 64, 56, 64, 1, 2, 0) + # verify_conv2d_nhwc(32, 128, 28, 128, 3, 1, 1) + # verify_conv2d_nhwc(32, 128, 28, 256, 3, 2, 1) + # verify_conv2d_nhwc(32, 128, 28, 256, 1, 2, 0) + # verify_conv2d_nhwc(32, 256, 14, 256, 3, 1, 1) + # verify_conv2d_nhwc(32, 256, 14, 512, 3, 2, 1) + # verify_conv2d_nhwc(32, 256, 14, 512, 1, 2, 0) + # verify_conv2d_nhwc(32, 512, 7, 512, 3, 1, 1) + + # verify_conv2d_nhwc(16, 64, 56, 64, 3, 1, 1) + # verify_conv2d_nhwc(16, 64, 56, 64, 1, 1, 0) + # verify_conv2d_nhwc(16, 64, 56, 128, 3, 2, 1) + # verify_conv2d_nhwc(16, 64, 56, 64, 1, 2, 0) + # verify_conv2d_nhwc(16, 128, 28, 128, 3, 1, 1) + # verify_conv2d_nhwc(16, 128, 28, 256, 3, 2, 1) + # verify_conv2d_nhwc(16, 128, 28, 256, 1, 2, 0) + # verify_conv2d_nhwc(16, 256, 14, 256, 3, 1, 1) + # verify_conv2d_nhwc(16, 256, 14, 512, 3, 2, 1) + # verify_conv2d_nhwc(16, 256, 14, 512, 1, 2, 0) + # verify_conv2d_nhwc(16, 512, 7, 512, 3, 1, 1) + + # verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) + # verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) + # verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) + # verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) + # verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1) + # verify_conv2d_nhwc(8, 128, 28, 256, 3, 2, 1) + # verify_conv2d_nhwc(8, 128, 28, 256, 1, 2, 0) + # verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1) + # verify_conv2d_nhwc(8, 256, 14, 512, 3, 2, 1) + # verify_conv2d_nhwc(8, 256, 14, 512, 1, 2, 0) + verify_conv2d_nhwc(8, 512, 7, 512, 3, 1, 1) + + + # verify_conv2d_nhwc(32, 1024, 14, 256, 1, 1, 1) + + # verify_conv2d_nhwc(16, 128, 7, 128, 7, 1, 3) + # verify_conv2d_nhwc(16, 160, 7, 160, 7, 1, 3) + + # verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_bias=True) + # verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True) + # verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True, add_bias=True) + + # verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, (3, 3, 2, 2)) + # verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, "SAME") + # verify_conv2d_nhwc(16, 48, 35, 48, 5, 1, "VALID") + # verify_conv2d_nhwc(16, 48, 56, 48, 3, 1, (1, 1, 1, 1)) + # verify_conv2d_nhwc(16, 64, 28, 64, 3, 1, (1, 1, 1, 1)) + +# import argparse +# parser = argparse.ArgumentParser() +# parser.add_argument('--brw', default=1, help="the base") +# parser.add_argument('--blw', default=1, help="the base") +# parser.add_argument('--wrt', default=1, help="the base") +# parser.add_argument('--wct', default=1, help="the base") +# parser.add_argument('--chunk', default=1, help="the base") +# parser.add_argument('--offset', default=0, help="the base") +# parser.add_argument('--vw', default=1, help="the base") +# parser.parse_args() + +if __name__ == "__main__": + # for brw in [1]: + # for blw in [2]: + # for wrt in [1]: + # for wct in [4]: + # for chunk in [1]: + # for offset in [0]: + # for vw in [1]: + # try: + # dic={'brw': brw, 'blw': blw,'wrt': wrt,'wct': wct,'chunk': chunk,'offset': offset, 'vw':vw} + test_conv2d_nhwc_tensorcore() + # except: + # pass From c666253798771bcb56621fe9c32b01d958784db5 Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Tue, 5 May 2020 14:07:04 +0000 Subject: [PATCH 03/21] update layout --- topi/python/topi/cuda/__init__.py | 1 + .../topi/cuda/conv2d_nhwc_tensorcore_int4.py | 91 ++++++---- .../test_topi_conv2d_nhwc_tensorcore_int4.py | 163 +++++++++--------- 3 files changed, 143 insertions(+), 112 deletions(-) diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 2b7a845cd9ec..ec2165dac6dd 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -46,5 +46,6 @@ from .rcnn import * from .sort import * from .conv2d_nhwc_tensorcore import * +from .conv2d_nhwc_tensorcore_int4 import * from .conv3d_ndhwc_tensorcore import * from .dense_tensorcore import * diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py index 186ed7e9fc6b..28ddf639f692 100644 --- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py +++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py @@ -24,6 +24,7 @@ from ..util import get_const_tuple, traverse_inline, simplify from ..nn.pad import pad from ..nn.util import get_pad_tuple +from topi.cuda.injective import schedule_injective_from_existing # from .tensor_intrin import intrin_wmma_load_matrix_A # from .tensor_intrin import intrin_wmma_load_matrix_W # from .tensor_intrin import intrin_wmma_store_matrix @@ -65,7 +66,7 @@ def intrin_wmma_gemm(): n = m = 8 l = 32 A = te.placeholder((n, l), name='A', dtype='int4') - B = te.placeholder((n, l), name='B', dtype='int4') + B = te.placeholder((m, l), name='B', dtype='int4') k = te.reduce_axis((0, l), name="k") C = te.compute((n, n), lambda ii, jj: @@ -135,15 +136,14 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype wmma_n = wmma_m = 8 wmma_k = 32 - batch, in_height, in_width, in_channels, wmma_m, wmma_k = get_const_tuple(Input.shape) + batch, in_height, in_width, in_channels= get_const_tuple(Input.shape) if in_dtype == 'int4': - kernel_h, kernel_w, _, num_filter, wmma_n, wmma_k = get_const_tuple(Filter.shape) + kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape) else: kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape) if in_dtype == 'int4': - pass - # assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0) + assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0) else: assert (batch % 16 == 0 and in_channels % 16 == 0 and num_filter % 16 == 0) or \ (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0) or \ @@ -164,61 +164,69 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype pad_after = [0, pad_down, pad_right, 0] # PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") # Input feature map: (N, H, W, IC, n, ic) - data_shape = (batch, + data_shape = (batch // wmma_m, in_height, in_width, - in_channels, + in_channels // wmma_k, wmma_m, wmma_k) # Kernel: (H, W, IC, OC, ic, oc) kernel_shape = (kernel_h, kernel_w, - in_channels, - out_channels, + in_channels // wmma_k, + out_channels // wmma_n, wmma_n, wmma_k) output_shape = (batch, out_height, out_width, - out_channels, - wmma_m, - wmma_n) + out_channels) # rc = te.reduce_axis((0, in_channel), name='rc') # ry = te.reduce_axis((0, kernel_h), name='ry') # rx = te.reduce_axis((0, kernel_w), name='rx') # Reduction axes kh = te.reduce_axis((0, kernel_h), name='kh') kw = te.reduce_axis((0, kernel_w), name='kw') - ic = te.reduce_axis((0, in_channels), name='ic') + ic = te.reduce_axis((0, in_channels // wmma_k), name='ic') ii = te.reduce_axis((0, wmma_k), name='ii') # Algorithm # A = te.placeholder(data_shape, name='A', dtype="int4") # W = te.placeholder(kernel_shape, name='W', dtype="int4") - Apad = te.compute( - (batch, in_height + 2 * padding, in_width + 2 * padding, in_channels, wmma_m, + A_transpose = te.compute(data_shape, + lambda n, h, w, i, nn, ii: Input[n * wmma_m + nn, h, w, i * wmma_k + ii] + ) + Filter_transpose = te.compute(kernel_shape, + lambda kh, kw, i, o, oo, ii: Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii] + ) + Apad_transpose = te.compute( + (batch // wmma_m, in_height + 2 * padding, in_width + 2 * padding, in_channels // wmma_k, wmma_m, wmma_k), lambda n, h, w, i, nn, ii: tvm.tir.if_then_else( tvm.tir.all(h >= padding, h - padding < in_height, w >= padding, w - padding < in_width), - Input[n, h - padding, w - padding, i, nn, ii], tvm.tir.const(0., "int4")), + A_transpose[n, h - padding, w - padding, i, nn, ii], tvm.tir.const(0., "int4")), name='Apad') - Conv = te.compute(output_shape, + Conv = te.compute((batch // wmma_m, out_height, out_width, out_channels // wmma_n, wmma_m, wmma_n), lambda n, h, w, o, nn, oo: te.sum( - Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("int32") * - Filter[kh, kw, ic, o, oo, ii].astype("int32"), + Apad_transpose[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("int32") * + Filter_transpose[kh, kw, ic, o, oo, ii].astype("int32"), axis=[ic, kh, kw, ii]), - name="Conv", tag="conv2d_nhwc_tensorcore_int4") + name="Conv") + Out = te.compute(output_shape, + lambda n, h, w, o: Conv[n // wmma_m, h, w, o // wmma_n, n % wmma_m, o % wmma_n], + name="Out", tag="conv2d_nhwc_tensorcore_int4") + return Out - return Conv - -def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Conv): +def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): """Schedule tensorcore template""" + Conv = s[Out].op.input_tensors[0] ic, kh, kw, ii = s[Conv].op.reduce_axis out_dtype = Conv.dtype # trans_paddata, kernel = s[Conv].op.input_tensors Apad, kernel = s[Conv].op.input_tensors - s[Apad].compute_inline() + A_transpose = s[Apad].op.input_tensors[0] + in_dtype = Apad.dtype batch, _, _, _, _, _ = get_const_tuple(Conv.shape) if in_dtype == 'int4': @@ -282,15 +290,36 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Conv): chunk = cfg["chunk"].val # offset = cfg["offset"].val vector_width = cfg["vector_width"].val - block_row_warps = 1 - block_col_warps = 8 - warp_row_tiles = 2 - warp_col_tiles = 1 - chunk = 4 - vector_width = 1 + # block_row_warps = 1 + # block_col_warps = 1 + # warp_row_tiles = 1 + # warp_col_tiles = 1 + # chunk = 1 + # vector_width = 1 # offset = 0 + with tvm.target.create('cuda'): + schedule_injective_from_existing(s, Out) + schedule_injective_from_existing(s, kernel) + s[Apad].compute_inline() + s[A_transpose].compute_inline() + # s[Out].compute_inline() + # s[kernel].compute_inline() + # if inline_apad: + # s[Apad].compute_inline() + # else: + # with tvm.target.create('cuda'): + # schedule_injective_from_existing(s, Apad) + # if inline_atranspose: + # s[A_transpose].compute_inline() + # else: + # with tvm.target.create('cuda'): + # schedule_injective_from_existing(s, A_transpose) + + + + if in_dtype == 'int4': wmma_m = wmma_n = 8 wmma_k = 32 @@ -343,7 +372,7 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Conv): ty, yo = s[AS].split(xo, nparts=block_col_warps) t = s[AS].fuse(nn, ii) to, ti = s[AS].split(t, factor=warp_size) - # ti, _t = s[AS].split(ti, factor=vector_width) + # ti, _t = s[AS].split(ti, factor=8) s[AS].bind(tx, thread_y) s[AS].bind(ty, thread_z) s[AS].bind(ti, thread_x) diff --git a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py index af135dd52a04..79bf3bbbe45b 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py @@ -71,24 +71,28 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, wmma_k = 32 in_height = in_width = in_size - A = te.placeholder((batch // wmma_m, in_height, in_width, in_channel // wmma_k, wmma_m, wmma_k), name='A', dtype=dtype) + A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype=dtype) + + # A = te.placeholder((batch // wmma_m, in_height, in_width, in_channel // wmma_k, wmma_m, wmma_k), name='A', dtype=dtype) if dtype == 'int4' or dtype == 'int8': - W = te.placeholder((kernel, kernel, in_channel // wmma_k, num_filter // wmma_n , wmma_n, wmma_k), name='W', dtype=dtype) + W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype) + # W = te.placeholder((kernel, kernel, in_channel // wmma_k, num_filter // wmma_n , wmma_n, wmma_k), name='W', dtype=dtype) else: W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype=dtype) bias = te.placeholder((1, 1, 1, num_filter), name='bias', dtype=out_dtype) - # a_shape = get_const_tuple(A.shape) - # w_shape = get_const_tuple(W.shape) + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) bias_shape = get_const_tuple(bias.shape) - a_shape = (batch, in_height, in_width, in_channel) + # a_shape = (batch, in_height, in_width, in_channel) # w_shape = (kernel, kernel, num_filter, in_channel) - w_shape = (kernel, kernel, in_channel, num_filter) + # w_shape = (kernel, kernel, in_channel, num_filter) # dtype = A.dtype @memoize("topi.tests.test_topi_conv2d_nhwc.verify_conv2d_nhwc") def get_ref_data(): + np.random.seed(5) if dtype == 'float16': a_np = np.random.uniform(size=a_shape).astype(dtype) w_np = np.random.uniform(size=w_shape).astype(dtype) @@ -98,7 +102,7 @@ def get_ref_data(): a_np = np.random.randint(low=-7, high=7, size=a_shape).astype(np.int32) b_np = np.random.randint(low=-7, high=7, size=bias_shape).astype(np.int32) w_np = np.random.randint(low=-7, high=7, size=w_shape).astype(np.int32) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + dw_np = topi.testing.dilate_python(w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation)) elif dtype == 'int8': a_np = np.random.randint(low=1, high=7, size=a_shape).astype(dtype) w_np = np.random.randint(low=1, high=7, size=w_shape).astype(dtype) @@ -123,36 +127,34 @@ def convert_int32_into_int4(a_int32): ------ a_int4 : int """ - M, N, I, J, K, L = a_int32.shape - a_int4 = np.zeros(shape=(M, N, I, J, K, L // 8), dtype=np.int32) - for m in range(M): - for n in range(N): - for i in range(I): - for j in range(J): - for k in range(K): - for l in range(L // 8): - for a in range(8): - a_int4[m,n,i,j,k,l] = a_int4[m,n,i,j,k,l] | ((a_int32[m,n,i,j,k,l * 8 + a] & 0xf) << ((7 - a) * 4)) + I, J, K, L = a_int32.shape + a_int4 = np.zeros(shape=(I, J, K, L // 8), dtype=np.int32) + for i in range(I): + for j in range(J): + for k in range(K): + for l in range(L // 8): + for a in range(8): + a_int4[i,j,k,l] = a_int4[i,j,k,l] | ((a_int32[i,j,k,l * 8 + a] & 0xf) << ((7 - a) * 4)) return a_int4 a_np, w_np, b_np, c_np = get_ref_data() if dtype == 'int4': - a_np_tvm = a_np.reshape((batch // wmma_m, - wmma_m, - in_height, - in_width, - in_channel // wmma_k, - wmma_k)).transpose((0,2,3,4,1,5)) - w_np_tvm = w_np.reshape((kernel, - kernel, - in_channel // wmma_k, - wmma_k, - num_filter // wmma_n, - wmma_n)).transpose((0,1,2,4,5,3)) - a_np = convert_int32_into_int4(a_np_tvm) + # a_np_tvm = a_np.reshape((batch // wmma_m, + # wmma_m, + # in_height, + # in_width, + # in_channel // wmma_k, + # wmma_k)).transpose((0,2,3,4,1,5)) + # w_np_tvm = w_np.reshape((kernel, + # kernel, + # in_channel // wmma_k, + # wmma_k, + # num_filter // wmma_n, + # wmma_n)).transpose((0,1,2,4,5,3)) + a_np = convert_int32_into_int4(a_np) # b_np = convert_int32_into_int4(b_np) - w_np = convert_int32_into_int4(w_np_tvm) + w_np = convert_int32_into_int4(w_np) def check_device(device): ctx = tvm.context(device, 0) @@ -192,53 +194,52 @@ def check_device(device): # warm up evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) func.time_evaluator(func.entry_name, ctx, number=50, repeat=20)(a, w, c) - # evaluator(a, w, c) print('Time cost of this operator: %f ms' % (evaluator(a, w, c).mean * 1000)) rtol = 1e-3 - # print(c.asnumpy().shape, c_np.sum()) - tvm.testing.assert_allclose(c.asnumpy().transpose((0,4,1,2,3,5)).reshape(c_np.shape), c_np, rtol=rtol) + # print(c.asnumpy().sum(), c_np.sum()) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol) # # #Tuning the performance - # import logging, sys - # logging.getLogger('autotvm').setLevel(logging.DEBUG) - # logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) - - # log_filename = "conv2d_int4_nhwc_tensorcore.log" - # tmp_log_file = log_filename + '.temp' - # num_trial = 2000 - # task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, - # padding, dilation) - # task = autotvm.create('conv2d_nhwc_tensorcore_int4.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) - # print(task.config_space) + import logging, sys + logging.getLogger('autotvm').setLevel(logging.DEBUG) + logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) + + log_filename = "conv2d_int4_nhwc_tensorcore.log" + tmp_log_file = log_filename + '.temp' + num_trial = 1000 + task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, + padding, dilation) + task = autotvm.create('conv2d_nhwc_tensorcore_int4.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) + print(task.config_space) - # measure_option = autotvm.measure_option( - # builder='local', - # runner=autotvm.LocalRunner(number=5)) - - # tuner = autotvm.tuner.XGBTuner(task) - # num_trial = min(num_trial, len(task.config_space)) - # with tvm.target.build_config(): - # tuner.tune(n_trial=num_trial, - # measure_option=measure_option, - # callbacks=[autotvm.callback.progress_bar(num_trial, prefix=task_name), - # autotvm.callback.log_to_file(tmp_log_file)]) - - # dispatch_context = autotvm.apply_history_best(tmp_log_file) - # best_config = dispatch_context.query(task.target, task.workload) - # print("\nBest config:") - # print(best_config) - - # #pick the best record to a cache file - # autotvm.record.pick_best(tmp_log_file, log_filename) - # os.remove(tmp_log_file) + measure_option = autotvm.measure_option( + builder='local', + runner=autotvm.LocalRunner(number=5)) + + tuner = autotvm.tuner.XGBTuner(task) + num_trial = min(num_trial, len(task.config_space)) + with tvm.target.build_config(): + tuner.tune(n_trial=num_trial, + measure_option=measure_option, + callbacks=[autotvm.callback.progress_bar(num_trial, prefix=task_name), + autotvm.callback.log_to_file(tmp_log_file)]) + + dispatch_context = autotvm.apply_history_best(tmp_log_file) + best_config = dispatch_context.query(task.target, task.workload) + print("\nBest config:") + print(best_config) + + #pick the best record to a cache file + autotvm.record.pick_best(tmp_log_file, log_filename) + os.remove(tmp_log_file) - # with autotvm.apply_graph_best(log_filename): - # with tvm.target.create(device): - # func = tvm.build(s, [A, W, C], device, name="conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, - # padding, dilation)) - # evaluator = func.time_evaluator(func.entry_name, ctx, number=100, repeat=10) - # print('Time cost of this operator after tuning: %f ms' % (evaluator(a, w, c).mean * 1000)) + with autotvm.apply_graph_best(log_filename): + with tvm.target.create(device): + func = tvm.build(s, [A, W, C], device, name="conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, + padding, dilation)) + evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) + print('Time cost of this operator after tuning: %f ms' % (evaluator(a, w, c).mean * 1000)) check_device(devices) @@ -281,16 +282,16 @@ def test_conv2d_nhwc_tensorcore(): # verify_conv2d_nhwc(16, 256, 14, 512, 1, 2, 0) # verify_conv2d_nhwc(16, 512, 7, 512, 3, 1, 1) - # verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) - # verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) - # verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) - # verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) - # verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1) - # verify_conv2d_nhwc(8, 128, 28, 256, 3, 2, 1) - # verify_conv2d_nhwc(8, 128, 28, 256, 1, 2, 0) - # verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1) - # verify_conv2d_nhwc(8, 256, 14, 512, 3, 2, 1) - # verify_conv2d_nhwc(8, 256, 14, 512, 1, 2, 0) + verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) + verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) + verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) + verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) + verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1) + verify_conv2d_nhwc(8, 128, 28, 256, 3, 2, 1) + verify_conv2d_nhwc(8, 128, 28, 256, 1, 2, 0) + verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1) + verify_conv2d_nhwc(8, 256, 14, 512, 3, 2, 1) + verify_conv2d_nhwc(8, 256, 14, 512, 1, 2, 0) verify_conv2d_nhwc(8, 512, 7, 512, 3, 1, 1) From f5d91da37d8368bf4037af508b74d73a75e54c2e Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Thu, 7 May 2020 02:30:19 +0000 Subject: [PATCH 04/21] add inline option --- .../topi/cuda/conv2d_nhwc_tensorcore_int4.py | 143 ++++++++--------- .../test_topi_conv2d_nhwc_tensorcore_int4.py | 148 ++++++++++-------- 2 files changed, 151 insertions(+), 140 deletions(-) diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py index 28ddf639f692..3bfa413a698e 100644 --- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py +++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py @@ -39,8 +39,6 @@ def intrin_wmma_load_matrix(scope): else: A = tvm.te.placeholder((m, l), name='A', dtype='int4') C = tvm.te.compute((m, l), lambda i, j: A[i, j], name='C') - # A = te.placeholder((n, m), name='A', dtype='int4') - # C = te.compute((m, n), lambda i, j: A[i, j], name='C') BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) @@ -138,19 +136,20 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype batch, in_height, in_width, in_channels= get_const_tuple(Input.shape) if in_dtype == 'int4': - kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape) + kernel_h, kernel_w, _, num_filter, _, _ = get_const_tuple(Filter.shape) else: kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape) - - if in_dtype == 'int4': - assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0) - else: - assert (batch % 16 == 0 and in_channels % 16 == 0 and num_filter % 16 == 0) or \ - (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0) or \ - (batch % 32 == 0 and in_channels % 16 == 0 and num_filter % 8 == 0), \ - "The shape of (batch, in_channels, num_filter) "\ - "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for fp16 and int8, "\ - "and (8, 32, 8) for int4" + num_filter = num_filter * wmma_n + # if in_dtype == 'int4': + # pass + # # assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0) + # else: + # assert (batch % 16 == 0 and in_channels % 16 == 0 and num_filter % 16 == 0) or \ + # (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0) or \ + # (batch % 32 == 0 and in_channels % 16 == 0 and num_filter % 8 == 0), \ + # "The shape of (batch, in_channels, num_filter) "\ + # "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for fp16 and int8, "\ + # "and (8, 32, 8) for int4" # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 @@ -181,23 +180,18 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype out_height, out_width, out_channels) - # rc = te.reduce_axis((0, in_channel), name='rc') - # ry = te.reduce_axis((0, kernel_h), name='ry') - # rx = te.reduce_axis((0, kernel_w), name='rx') # Reduction axes kh = te.reduce_axis((0, kernel_h), name='kh') kw = te.reduce_axis((0, kernel_w), name='kw') ic = te.reduce_axis((0, in_channels // wmma_k), name='ic') ii = te.reduce_axis((0, wmma_k), name='ii') # Algorithm - # A = te.placeholder(data_shape, name='A', dtype="int4") - # W = te.placeholder(kernel_shape, name='W', dtype="int4") A_transpose = te.compute(data_shape, lambda n, h, w, i, nn, ii: Input[n * wmma_m + nn, h, w, i * wmma_k + ii] ) - Filter_transpose = te.compute(kernel_shape, - lambda kh, kw, i, o, oo, ii: Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii] - ) + # Filter_transpose = te.compute(kernel_shape, + # lambda kh, kw, i, o, oo, ii: Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii] + # ) Apad_transpose = te.compute( (batch // wmma_m, in_height + 2 * padding, in_width + 2 * padding, in_channels // wmma_k, wmma_m, wmma_k), @@ -209,7 +203,7 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype Conv = te.compute((batch // wmma_m, out_height, out_width, out_channels // wmma_n, wmma_m, wmma_n), lambda n, h, w, o, nn, oo: te.sum( Apad_transpose[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("int32") * - Filter_transpose[kh, kw, ic, o, oo, ii].astype("int32"), + Filter[kh, kw, ic, o, oo, ii].astype("int32"), axis=[ic, kh, kw, ii]), name="Conv") Out = te.compute(output_shape, @@ -234,8 +228,6 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): else: _, _, _, out_channels, _, _ = get_const_tuple(kernel.shape) # inline the pad and dtype transform - # s[kernel].compute_inline() - # s[paddata[0]].compute_inline() block_x = te.thread_axis('blockIdx.x') block_y = te.thread_axis('blockIdx.y') @@ -251,30 +243,15 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): WF = s.cache_read(WS, 'wmma.matrix_b', [Conv]) ConvF = s.cache_write(Conv, 'wmma.accumulator') - # todo - # if Conv.op in s.outputs: - # output = Conv - # ConvS = s.cache_read(ConvF, 'shared', [Conv]) - # OL = ConvS - # else: - # output = s.outputs[0].output(0) - # s[Conv].set_scope('shared') - # OL = Conv - # Schedule for autotvm - cfg.define_knob("block_row_warps", [1, 2, 4, 8]) - cfg.define_knob("block_col_warps", [1, 2, 4, 8]) - cfg.define_knob("warp_row_tiles", [1, 2, 4, 8]) - cfg.define_knob("warp_col_tiles", [1, 2, 4, 8]) - cfg.define_knob("chunk", [1, 2, 4, 8]) - # if in_dtype == 'int8': - # cfg.define_knob("offset", [0, 16]) - # elif in_dtype == 'int4': - # cfg.define_knob("offset", [0]) - # else: - # cfg.define_knob("offset", [0, 8]) - # cfg.define_knob("vector_width", [1, 2, 4, 8]) - cfg.define_knob("vector_width", [1, 8]) + cfg.define_knob("block_row_warps", [1, 2, 4, 8, 16]) + cfg.define_knob("block_col_warps", [1, 2, 4, 8, 16]) + cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16]) + cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16]) + cfg.define_knob("chunk", [1, 2, 4, 8, 16]) + cfg.define_knob("vector_ws", [1, 8]) + # cfg.define_knob("inline_pad", [1, 2]) + cfg.define_knob("vector_as", [1, 4, 8, 16]) # fallback support target = tvm.target.Target.current() @@ -289,36 +266,32 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): warp_col_tiles = cfg["warp_col_tiles"].val chunk = cfg["chunk"].val # offset = cfg["offset"].val - vector_width = cfg["vector_width"].val - # block_row_warps = 1 - # block_col_warps = 1 - # warp_row_tiles = 1 - # warp_col_tiles = 1 - # chunk = 1 - # vector_width = 1 - - # offset = 0 + vector_ws = cfg["vector_ws"].val + vector_as = cfg["vector_as"].val + # inline_pad = cfg["inline_pad"].val + block_row_warps = 1 + block_col_warps = 1 + warp_row_tiles = 8 + warp_col_tiles = 8 + chunk = 1 + vector_ws = 1 + vector_as = 16 + + inline_pad = 2 with tvm.target.create('cuda'): schedule_injective_from_existing(s, Out) - schedule_injective_from_existing(s, kernel) - s[Apad].compute_inline() + # schedule_injective_from_existing(s, Apad) + # schedule_injective_from_existing(s, A_transpose) + # s[Apad].compute_inline() s[A_transpose].compute_inline() # s[Out].compute_inline() - # s[kernel].compute_inline() - # if inline_apad: - # s[Apad].compute_inline() - # else: - # with tvm.target.create('cuda'): - # schedule_injective_from_existing(s, Apad) - # if inline_atranspose: - # s[A_transpose].compute_inline() - # else: - # with tvm.target.create('cuda'): - # schedule_injective_from_existing(s, A_transpose) - - + if inline_pad == 1: + s[Apad].compute_inline() + else: + with tvm.target.create('cuda'): + schedule_injective_from_existing(s, Apad) if in_dtype == 'int4': wmma_m = wmma_n = 8 @@ -342,6 +315,22 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): warp_size = 32 + # _n, _h, _w, _c = s[Out].op.axis + # block_k = s[Out].fuse(_h, _w) + # s[Out].bind(block_k, block_z) + # _n, _wmma_m = s[Out].split(_n, factor=8) + # _n, nci = s[Out].split(_n, factor=warp_row_tiles) + # block_i, _n = s[Out].split(_n, factor=block_row_warps) + # _c, _wmma_n = s[Out].split(_c, factor=8) + # _c, oci = s[Out].split(_c, factor=warp_col_tiles) + # block_j, _c = s[Out].split(_c, factor=block_col_warps) + # s[Out].reorder(block_k, block_i, block_j, _n, _c, nci, oci, _wmma_m, _wmma_n) + # s[Out].bind(block_i, block_x) + # s[Out].bind(block_j, block_y) + # s[Out].bind(_n, thread_y) + # s[Out].bind(_c, thread_z) + # s[Conv].compute_at(s[Out], Out.op.axis[2]) + nc, hc, wc, oc, nnc, ooc = Conv.op.axis block_k = s[Conv].fuse(hc, wc) s[Conv].bind(block_k, block_z) @@ -354,6 +343,7 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): s[Conv].bind(block_j, block_y) s[Conv].bind(nc, thread_y) s[Conv].bind(oc, thread_z) + # Schedule local computation s[ConvF].compute_at(s[Conv], oc) n, h, w, o, nnf, oof = ConvF.op.axis @@ -364,19 +354,18 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): s[AF].compute_at(s[ConvF], kw) s[WF].compute_at(s[ConvF], kw) - # vector_width=8 # Schedule for A's share memory s[AS].compute_at(s[ConvF], kh) n, h, w, i, nn, ii = AS.op.axis tx, xo = s[AS].split(n, nparts=block_row_warps) ty, yo = s[AS].split(xo, nparts=block_col_warps) t = s[AS].fuse(nn, ii) - to, ti = s[AS].split(t, factor=warp_size) - # ti, _t = s[AS].split(ti, factor=8) + to, ti = s[AS].split(t, nparts=warp_size) + ti, _t = s[AS].split(ti, factor=vector_as) s[AS].bind(tx, thread_y) s[AS].bind(ty, thread_z) - s[AS].bind(ti, thread_x) - # s[AS].vectorize(ti) + s[AS].bind(to, thread_x) + s[AS].vectorize(_t) # Schedule for W's share memory s[WS].compute_at(s[ConvF], kh) @@ -385,7 +374,7 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): ty, yo = s[WS].split(xo, nparts=block_col_warps) t = s[WS].fuse(ii, oo) to, ti = s[WS].split(t, nparts=warp_size) - ti, _t = s[WS].split(ti, factor=vector_width) + ti, _t = s[WS].split(ti, factor=vector_ws) s[WS].bind(tx, thread_y) s[WS].bind(ty, thread_z) s[WS].bind(to, thread_x) diff --git a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py index 79bf3bbbe45b..4a9460702a14 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py @@ -75,8 +75,8 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, # A = te.placeholder((batch // wmma_m, in_height, in_width, in_channel // wmma_k, wmma_m, wmma_k), name='A', dtype=dtype) if dtype == 'int4' or dtype == 'int8': - W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype) - # W = te.placeholder((kernel, kernel, in_channel // wmma_k, num_filter // wmma_n , wmma_n, wmma_k), name='W', dtype=dtype) + # W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype) + W = te.placeholder((kernel, kernel, in_channel // wmma_k, num_filter // wmma_n , wmma_n, wmma_k), name='W', dtype=dtype) else: W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype=dtype) @@ -86,7 +86,7 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, w_shape = get_const_tuple(W.shape) bias_shape = get_const_tuple(bias.shape) # a_shape = (batch, in_height, in_width, in_channel) - # w_shape = (kernel, kernel, num_filter, in_channel) + w_shape = (kernel, kernel, in_channel, num_filter) # w_shape = (kernel, kernel, in_channel, num_filter) # dtype = A.dtype @@ -99,10 +99,10 @@ def get_ref_data(): b_np = np.random.uniform(size=bias_shape).astype(out_dtype) dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) elif dtype == 'int4': - a_np = np.random.randint(low=-7, high=7, size=a_shape).astype(np.int32) - b_np = np.random.randint(low=-7, high=7, size=bias_shape).astype(np.int32) - w_np = np.random.randint(low=-7, high=7, size=w_shape).astype(np.int32) - dw_np = topi.testing.dilate_python(w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation)) + a_np = np.random.randint(low=1, high=7, size=a_shape).astype(np.int32) + b_np = np.random.randint(low=1, high=7, size=bias_shape).astype(np.int32) + w_np = np.random.randint(low=1, high=7, size=w_shape).astype(np.int32) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) elif dtype == 'int8': a_np = np.random.randint(low=1, high=7, size=a_shape).astype(dtype) w_np = np.random.randint(low=1, high=7, size=w_shape).astype(dtype) @@ -136,6 +136,27 @@ def convert_int32_into_int4(a_int32): for a in range(8): a_int4[i,j,k,l] = a_int4[i,j,k,l] | ((a_int32[i,j,k,l * 8 + a] & 0xf) << ((7 - a) * 4)) return a_int4 + def convert_int32_into_int4_shape6(a_int32): + """ convert int32 values into int4 + Parameters + ---------- + a_int32 : int + + Return + ------ + a_int4 : int + """ + M, N, I, J, K, L = a_int32.shape + a_int4 = np.zeros(shape=(M, N, I, J, K, L // 8), dtype=np.int32) + for m in range(M): + for n in range (N): + for i in range(I): + for j in range(J): + for k in range(K): + for l in range(L // 8): + for a in range(8): + a_int4[m,n,i,j,k,l] = a_int4[m,n,i,j,k,l] | ((a_int32[m,n,i,j,k,l * 8 + a] & 0xf) << ((7 - a) * 4)) + return a_int4 a_np, w_np, b_np, c_np = get_ref_data() @@ -146,15 +167,15 @@ def convert_int32_into_int4(a_int32): # in_width, # in_channel // wmma_k, # wmma_k)).transpose((0,2,3,4,1,5)) - # w_np_tvm = w_np.reshape((kernel, - # kernel, - # in_channel // wmma_k, - # wmma_k, - # num_filter // wmma_n, - # wmma_n)).transpose((0,1,2,4,5,3)) + w_np = w_np.reshape((kernel, + kernel, + in_channel // wmma_k, + wmma_k, + num_filter // wmma_n, + wmma_n)).transpose((0,1,2,4,5,3)) a_np = convert_int32_into_int4(a_np) # b_np = convert_int32_into_int4(b_np) - w_np = convert_int32_into_int4(w_np) + w_np = convert_int32_into_int4_shape6(w_np) def check_device(device): ctx = tvm.context(device, 0) @@ -185,7 +206,7 @@ def check_device(device): batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, b, c) else: - # print(tvm.lower(s, [A, W, C], simple_mode=True)) + print(tvm.lower(s, [A, W, C], simple_mode=True)) func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, c) @@ -193,7 +214,7 @@ def check_device(device): # print(dev_module.get_source()) # warm up evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) - func.time_evaluator(func.entry_name, ctx, number=50, repeat=20)(a, w, c) + evaluator(a, w, c) print('Time cost of this operator: %f ms' % (evaluator(a, w, c).mean * 1000)) rtol = 1e-3 @@ -201,45 +222,46 @@ def check_device(device): tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol) # # #Tuning the performance - import logging, sys - logging.getLogger('autotvm').setLevel(logging.DEBUG) - logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) - - log_filename = "conv2d_int4_nhwc_tensorcore.log" - tmp_log_file = log_filename + '.temp' - num_trial = 1000 - task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, - padding, dilation) - task = autotvm.create('conv2d_nhwc_tensorcore_int4.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) - print(task.config_space) + # import logging, sys + # logging.getLogger('autotvm').setLevel(logging.DEBUG) + # logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) + + # log_filename = "conv2d_int4_nhwc_tensorcore_injectpad_%d_%d_%d_%d_%d_%d_%d_%d.log" % (batch, in_channel, in_size, num_filter, kernel, stride, + # padding, dilation) + # tmp_log_file = log_filename + '.temp' + # num_trial = 2000 + # task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, + # padding, dilation) + # task = autotvm.create('conv2d_nhwc_tensorcore_int4.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) + # print(task.config_space) - measure_option = autotvm.measure_option( - builder='local', - runner=autotvm.LocalRunner(number=5)) - - tuner = autotvm.tuner.XGBTuner(task) - num_trial = min(num_trial, len(task.config_space)) - with tvm.target.build_config(): - tuner.tune(n_trial=num_trial, - measure_option=measure_option, - callbacks=[autotvm.callback.progress_bar(num_trial, prefix=task_name), - autotvm.callback.log_to_file(tmp_log_file)]) - - dispatch_context = autotvm.apply_history_best(tmp_log_file) - best_config = dispatch_context.query(task.target, task.workload) - print("\nBest config:") - print(best_config) - - #pick the best record to a cache file - autotvm.record.pick_best(tmp_log_file, log_filename) - os.remove(tmp_log_file) + # measure_option = autotvm.measure_option( + # builder='local', + # runner=autotvm.LocalRunner(number=5)) + + # tuner = autotvm.tuner.XGBTuner(task) + # num_trial = min(num_trial, len(task.config_space)) + # with tvm.target.build_config(): + # tuner.tune(n_trial=num_trial, + # measure_option=measure_option, + # callbacks=[autotvm.callback.progress_bar(num_trial, prefix=task_name), + # autotvm.callback.log_to_file(tmp_log_file)]) + + # dispatch_context = autotvm.apply_history_best(tmp_log_file) + # best_config = dispatch_context.query(task.target, task.workload) + # print("\nBest config:") + # print(best_config) + + # #pick the best record to a cache file + # autotvm.record.pick_best(tmp_log_file, log_filename) + # os.remove(tmp_log_file) - with autotvm.apply_graph_best(log_filename): - with tvm.target.create(device): - func = tvm.build(s, [A, W, C], device, name="conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, - padding, dilation)) - evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) - print('Time cost of this operator after tuning: %f ms' % (evaluator(a, w, c).mean * 1000)) + # with autotvm.apply_graph_best(log_filename): + # with tvm.target.create(device): + # func = tvm.build(s, [A, W, C], device, name="conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, + # padding, dilation)) + # evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) + # print('Time cost of this operator after tuning: %f ms' % (evaluator(a, w, c).mean * 1000)) check_device(devices) @@ -283,16 +305,16 @@ def test_conv2d_nhwc_tensorcore(): # verify_conv2d_nhwc(16, 512, 7, 512, 3, 1, 1) verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) - verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) - verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) - verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) - verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1) - verify_conv2d_nhwc(8, 128, 28, 256, 3, 2, 1) - verify_conv2d_nhwc(8, 128, 28, 256, 1, 2, 0) - verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1) - verify_conv2d_nhwc(8, 256, 14, 512, 3, 2, 1) - verify_conv2d_nhwc(8, 256, 14, 512, 1, 2, 0) - verify_conv2d_nhwc(8, 512, 7, 512, 3, 1, 1) + # verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) + # verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) + # verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) + # verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1) + # verify_conv2d_nhwc(8, 128, 28, 256, 3, 2, 1) + # verify_conv2d_nhwc(8, 128, 28, 256, 1, 2, 0) + # verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1) + # verify_conv2d_nhwc(8, 256, 14, 512, 3, 2, 1) + # verify_conv2d_nhwc(8, 256, 14, 512, 1, 2, 0) + # verify_conv2d_nhwc(8, 512, 7, 512, 3, 1, 1) # verify_conv2d_nhwc(32, 1024, 14, 256, 1, 1, 1) From 2797ad0c12cec1e8342799ebaea96499f112249a Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Thu, 7 May 2020 02:42:46 +0000 Subject: [PATCH 05/21] clean code --- .../test_topi_conv2d_nhwc_tensorcore_int4.py | 41 +++++-------------- 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py index 4a9460702a14..fadd96ae7ec9 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py @@ -305,16 +305,16 @@ def test_conv2d_nhwc_tensorcore(): # verify_conv2d_nhwc(16, 512, 7, 512, 3, 1, 1) verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) - # verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) - # verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) - # verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) - # verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1) - # verify_conv2d_nhwc(8, 128, 28, 256, 3, 2, 1) - # verify_conv2d_nhwc(8, 128, 28, 256, 1, 2, 0) - # verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1) - # verify_conv2d_nhwc(8, 256, 14, 512, 3, 2, 1) - # verify_conv2d_nhwc(8, 256, 14, 512, 1, 2, 0) - # verify_conv2d_nhwc(8, 512, 7, 512, 3, 1, 1) + verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) + verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) + verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) + verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1) + verify_conv2d_nhwc(8, 128, 28, 256, 3, 2, 1) + verify_conv2d_nhwc(8, 128, 28, 256, 1, 2, 0) + verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1) + verify_conv2d_nhwc(8, 256, 14, 512, 3, 2, 1) + verify_conv2d_nhwc(8, 256, 14, 512, 1, 2, 0) + verify_conv2d_nhwc(8, 512, 7, 512, 3, 1, 1) # verify_conv2d_nhwc(32, 1024, 14, 256, 1, 1, 1) @@ -332,27 +332,6 @@ def test_conv2d_nhwc_tensorcore(): # verify_conv2d_nhwc(16, 48, 56, 48, 3, 1, (1, 1, 1, 1)) # verify_conv2d_nhwc(16, 64, 28, 64, 3, 1, (1, 1, 1, 1)) -# import argparse -# parser = argparse.ArgumentParser() -# parser.add_argument('--brw', default=1, help="the base") -# parser.add_argument('--blw', default=1, help="the base") -# parser.add_argument('--wrt', default=1, help="the base") -# parser.add_argument('--wct', default=1, help="the base") -# parser.add_argument('--chunk', default=1, help="the base") -# parser.add_argument('--offset', default=0, help="the base") -# parser.add_argument('--vw', default=1, help="the base") -# parser.parse_args() if __name__ == "__main__": - # for brw in [1]: - # for blw in [2]: - # for wrt in [1]: - # for wct in [4]: - # for chunk in [1]: - # for offset in [0]: - # for vw in [1]: - # try: - # dic={'brw': brw, 'blw': blw,'wrt': wrt,'wct': wct,'chunk': chunk,'offset': offset, 'vw':vw} test_conv2d_nhwc_tensorcore() - # except: - # pass From be92a0f5b7363ce7d7729834c5a4066636aced5b Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Tue, 12 May 2020 14:05:19 +0000 Subject: [PATCH 06/21] increase search space --- .../topi/cuda/conv2d_nhwc_tensorcore_int4.py | 118 ++++++++---------- .../test_topi_conv2d_nhwc_tensorcore_int4.py | 44 ++++--- 2 files changed, 78 insertions(+), 84 deletions(-) diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py index 3bfa413a698e..aac9dda71f66 100644 --- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py +++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py @@ -30,18 +30,15 @@ # from .tensor_intrin import intrin_wmma_store_matrix # from .tensor_intrin import intrin_wmma_gemm -def intrin_wmma_load_matrix(scope): - n = m = 8 - l = 32 +def intrin_wmma_load_matrix(scope, m, n, l, in_dtype): if scope == 'wmma.matrix_a': - A = tvm.te.placeholder((n, l), name='A', dtype='int4') - C = tvm.te.compute((n, l), lambda i, j: A[i, j], name='C') - else: - A = tvm.te.placeholder((m, l), name='A', dtype='int4') + A = tvm.te.placeholder((m, l), name='A', dtype=in_dtype) C = tvm.te.compute((m, l), lambda i, j: A[i, j], name='C') + else: + A = tvm.te.placeholder((n, l), name='A', dtype=in_dtype) + C = tvm.te.compute((n, l), lambda i, j: A[i, j], name='C') BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) - def intrin_func(ins, outs): ib = tvm.tir.ir_builder.create() @@ -60,13 +57,11 @@ def intrin_func(ins, outs): return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) -def intrin_wmma_gemm(): - n = m = 8 - l = 32 - A = te.placeholder((n, l), name='A', dtype='int4') - B = te.placeholder((m, l), name='B', dtype='int4') +def intrin_wmma_gemm(m, n, l, in_dtype): + A = te.placeholder((m, l), name='A', dtype=in_dtype) + B = te.placeholder((n, l), name='B', dtype=in_dtype) k = te.reduce_axis((0, l), name="k") - C = te.compute((n, n), + C = te.compute((m, n), lambda ii, jj: te.sum(A[ii, k].astype('int32') * B[jj, k].astype('int32'), axis=k), name='C') @@ -97,12 +92,10 @@ def update(): return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) -def intrin_wmma_store_matrix(): - n = m = 8 - l = 32 - A = te.placeholder((n, m), name='A', dtype='int32') +def intrin_wmma_store_matrix(m, n, l): + A = te.placeholder((m, n), name='A', dtype='int32') BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=64) - C = te.compute((n, m), lambda i, j: A[i, j], name='C') + C = te.compute((m, n), lambda i, j: A[i, j], name='C') BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=64) def intrin_func(ins, outs): @@ -131,14 +124,19 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype else: dilation_h, dilation_w = dilation - wmma_n = wmma_m = 8 - wmma_k = 32 + if in_dtype == 'int4': + wmma_n = wmma_m = 8 + wmma_k = 32 + else: + wmma_m = 16 + wmma_n = 16 + wmma_k = 16 batch, in_height, in_width, in_channels= get_const_tuple(Input.shape) if in_dtype == 'int4': kernel_h, kernel_w, _, num_filter, _, _ = get_const_tuple(Filter.shape) else: - kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape) + kernel_h, kernel_w, _, num_filter, _, _ = get_const_tuple(Filter.shape) num_filter = num_filter * wmma_n # if in_dtype == 'int4': # pass @@ -161,7 +159,10 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) pad_before = [0, pad_top, pad_left, 0] pad_after = [0, pad_down, pad_right, 0] - # PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") + PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") + TransPaddedInput = te.compute( + PaddedInput.shape, + lambda n, h, w, c: PaddedInput[n, h, w, c].astype(in_dtype)) # Input feature map: (N, H, W, IC, n, ic) data_shape = (batch // wmma_m, in_height, @@ -187,7 +188,7 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype ii = te.reduce_axis((0, wmma_k), name='ii') # Algorithm A_transpose = te.compute(data_shape, - lambda n, h, w, i, nn, ii: Input[n * wmma_m + nn, h, w, i * wmma_k + ii] + lambda n, h, w, i, nn, ii: Input[n * wmma_m + nn, h, w, i * wmma_k + ii].astype(in_dtype) ) # Filter_transpose = te.compute(kernel_shape, # lambda kh, kw, i, o, oo, ii: Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii] @@ -198,7 +199,7 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype lambda n, h, w, i, nn, ii: tvm.tir.if_then_else( tvm.tir.all(h >= padding, h - padding < in_height, w >= padding, w - padding < in_width), - A_transpose[n, h - padding, w - padding, i, nn, ii], tvm.tir.const(0., "int4")), + A_transpose[n, h - padding, w - padding, i, nn, ii], tvm.tir.const(0., in_dtype)), name='Apad') Conv = te.compute((batch // wmma_m, out_height, out_width, out_channels // wmma_n, wmma_m, wmma_n), lambda n, h, w, o, nn, oo: te.sum( @@ -226,7 +227,7 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): if in_dtype == 'int4': _, _, out_channels, _, _, _ = get_const_tuple(kernel.shape) else: - _, _, _, out_channels, _, _ = get_const_tuple(kernel.shape) + _, _, out_channels, _, _, _ = get_const_tuple(kernel.shape) # inline the pad and dtype transform block_x = te.thread_axis('blockIdx.x') @@ -244,14 +245,15 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): ConvF = s.cache_write(Conv, 'wmma.accumulator') # Schedule for autotvm - cfg.define_knob("block_row_warps", [1, 2, 4, 8, 16]) - cfg.define_knob("block_col_warps", [1, 2, 4, 8, 16]) - cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16]) - cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16]) - cfg.define_knob("chunk", [1, 2, 4, 8, 16]) + cfg.define_knob("block_row_warps", [1, 2]) + cfg.define_knob("block_col_warps", [1, 2]) + cfg.define_knob("warp_row_tiles", [1, 2, 4, 8]) + cfg.define_knob("warp_col_tiles", [1, 2, 4, 8]) + cfg.define_knob("chunk", [1, 2, 4, 8]) cfg.define_knob("vector_ws", [1, 8]) # cfg.define_knob("inline_pad", [1, 2]) cfg.define_knob("vector_as", [1, 4, 8, 16]) + cfg.define_knob("split_block_k", [1, 2, 4, 8]) # fallback support target = tvm.target.Target.current() @@ -268,14 +270,15 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): # offset = cfg["offset"].val vector_ws = cfg["vector_ws"].val vector_as = cfg["vector_as"].val - # inline_pad = cfg["inline_pad"].val + split_block_k = cfg["split_block_k"].val block_row_warps = 1 block_col_warps = 1 warp_row_tiles = 8 - warp_col_tiles = 8 - chunk = 1 + warp_col_tiles = 4 + chunk = 4 vector_ws = 1 vector_as = 16 + split_block_k = 8 inline_pad = 2 @@ -297,14 +300,14 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): wmma_m = wmma_n = 8 wmma_k = 32 else: - if (batch % 16 == 0 and out_channels % 16 == 0): - cfg.define_knob("wmma_m", [16, 8, 32]) - elif (batch % 8 == 0 and out_channels % 32 == 0): - cfg.define_knob("wmma_m", [8, 16, 32]) - elif (batch % 32 == 0 and out_channels % 8 == 0): - cfg.define_knob("wmma_m", [32, 16, 8]) - wmma_m = cfg["wmma_m"].val - # wmma_m = 16 + # if (batch % 16 == 0 and out_channels % 16 == 0): + # cfg.define_knob("wmma_m", [16, 8, 32]) + # elif (batch % 8 == 0 and out_channels % 32 == 0): + # cfg.define_knob("wmma_m", [8, 16, 32]) + # elif (batch % 32 == 0 and out_channels % 8 == 0): + # cfg.define_knob("wmma_m", [32, 16, 8]) + # wmma_m = cfg["wmma_m"].val + wmma_m = 16 wmma_k = 16 if wmma_m == 16: wmma_n = 16 @@ -315,30 +318,15 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): warp_size = 32 - # _n, _h, _w, _c = s[Out].op.axis - # block_k = s[Out].fuse(_h, _w) - # s[Out].bind(block_k, block_z) - # _n, _wmma_m = s[Out].split(_n, factor=8) - # _n, nci = s[Out].split(_n, factor=warp_row_tiles) - # block_i, _n = s[Out].split(_n, factor=block_row_warps) - # _c, _wmma_n = s[Out].split(_c, factor=8) - # _c, oci = s[Out].split(_c, factor=warp_col_tiles) - # block_j, _c = s[Out].split(_c, factor=block_col_warps) - # s[Out].reorder(block_k, block_i, block_j, _n, _c, nci, oci, _wmma_m, _wmma_n) - # s[Out].bind(block_i, block_x) - # s[Out].bind(block_j, block_y) - # s[Out].bind(_n, thread_y) - # s[Out].bind(_c, thread_z) - # s[Conv].compute_at(s[Out], Out.op.axis[2]) - nc, hc, wc, oc, nnc, ooc = Conv.op.axis block_k = s[Conv].fuse(hc, wc) - s[Conv].bind(block_k, block_z) + block_k, sub_block_k = s[Conv].split(block_k, factor=split_block_k) nc, nci = s[Conv].split(nc, factor=warp_row_tiles) block_i, nc = s[Conv].split(nc, factor=block_row_warps) oc, oci = s[Conv].split(oc, factor=warp_col_tiles) block_j, oc = s[Conv].split(oc, factor=block_col_warps) - s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc) + s[Conv].reorder(block_k, block_i, block_j, sub_block_k, nc, oc, nci, oci, nnc, ooc) + s[Conv].bind(block_k, block_z) s[Conv].bind(block_i, block_x) s[Conv].bind(block_j, block_y) s[Conv].bind(nc, thread_y) @@ -380,17 +368,17 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): s[WS].bind(to, thread_x) s[WS].vectorize(ti) - s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a')) - s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b')) - s[Conv].tensorize(nnc, intrin_wmma_store_matrix()) - s[ConvF].tensorize(nnf, intrin_wmma_gemm()) + s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a', wmma_m, wmma_n, wmma_k, in_dtype)) + s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b', wmma_m, wmma_n, wmma_k, in_dtype)) + s[Conv].tensorize(nnc, intrin_wmma_store_matrix(wmma_m, wmma_n, wmma_k)) + s[ConvF].tensorize(nnf, intrin_wmma_gemm(wmma_m, wmma_n, wmma_k, in_dtype)) N, OH, OW, CO, nn, mm = get_const_tuple(Conv.shape) if in_dtype == 'int4': KH, KW, _, CI, _, ci = get_const_tuple(kernel.shape) else: - KH, KW, CI, _ = get_const_tuple(kernel.shape) + KH, KW, _, CI, _, ci = get_const_tuple(kernel.shape) cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * ci * nn * mm) diff --git a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py index fadd96ae7ec9..68aa9f65c9dc 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py @@ -67,8 +67,13 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, # choose dtype from int4, int8 and float16 dtype = 'int4' out_dtype = 'int32' - wmma_n = wmma_m = 8 - wmma_k = 32 + if dtype == 'int4': + wmma_n = wmma_m = 8 + wmma_k = 32 + else: + wmma_m = 16 + wmma_n = 16 + wmma_k = 16 in_height = in_width = in_size A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype=dtype) @@ -78,7 +83,7 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, # W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype) W = te.placeholder((kernel, kernel, in_channel // wmma_k, num_filter // wmma_n , wmma_n, wmma_k), name='W', dtype=dtype) else: - W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype=dtype) + W = te.placeholder((kernel, kernel, in_channel // wmma_k, num_filter // wmma_n , wmma_n, wmma_k), name='W', dtype=dtype) bias = te.placeholder((1, 1, 1, num_filter), name='bias', dtype=out_dtype) @@ -160,7 +165,7 @@ def convert_int32_into_int4_shape6(a_int32): a_np, w_np, b_np, c_np = get_ref_data() - if dtype == 'int4': + if dtype == 'int4' or dtype == 'int8': # a_np_tvm = a_np.reshape((batch // wmma_m, # wmma_m, # in_height, @@ -173,9 +178,10 @@ def convert_int32_into_int4_shape6(a_int32): wmma_k, num_filter // wmma_n, wmma_n)).transpose((0,1,2,4,5,3)) - a_np = convert_int32_into_int4(a_np) - # b_np = convert_int32_into_int4(b_np) - w_np = convert_int32_into_int4_shape6(w_np) + if dtype == 'int4': + a_np = convert_int32_into_int4(a_np) + # b_np = convert_int32_into_int4(b_np) + w_np = convert_int32_into_int4_shape6(w_np) def check_device(device): ctx = tvm.context(device, 0) @@ -226,10 +232,10 @@ def check_device(device): # logging.getLogger('autotvm').setLevel(logging.DEBUG) # logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) - # log_filename = "conv2d_int4_nhwc_tensorcore_injectpad_%d_%d_%d_%d_%d_%d_%d_%d.log" % (batch, in_channel, in_size, num_filter, kernel, stride, + # log_filename = "conv2d_int8_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d.log" % (batch, in_channel, in_size, num_filter, kernel, stride, # padding, dilation) # tmp_log_file = log_filename + '.temp' - # num_trial = 2000 + # num_trial = 1000 # task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, # padding, dilation) # task = autotvm.create('conv2d_nhwc_tensorcore_int4.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) @@ -305,16 +311,16 @@ def test_conv2d_nhwc_tensorcore(): # verify_conv2d_nhwc(16, 512, 7, 512, 3, 1, 1) verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) - verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) - verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) - verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) - verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1) - verify_conv2d_nhwc(8, 128, 28, 256, 3, 2, 1) - verify_conv2d_nhwc(8, 128, 28, 256, 1, 2, 0) - verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1) - verify_conv2d_nhwc(8, 256, 14, 512, 3, 2, 1) - verify_conv2d_nhwc(8, 256, 14, 512, 1, 2, 0) - verify_conv2d_nhwc(8, 512, 7, 512, 3, 1, 1) + # verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) + # verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) + # verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) + # verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1) + # verify_conv2d_nhwc(8, 128, 28, 256, 3, 2, 1) + # verify_conv2d_nhwc(8, 128, 28, 256, 1, 2, 0) + # verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1) + # verify_conv2d_nhwc(8, 256, 14, 512, 3, 2, 1) + # verify_conv2d_nhwc(8, 256, 14, 512, 1, 2, 0) + # verify_conv2d_nhwc(8, 512, 7, 512, 3, 1, 1) # verify_conv2d_nhwc(32, 1024, 14, 256, 1, 1, 1) From d0bd2ad5383322e10edbd8331132dfcbf6133404 Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Thu, 14 May 2020 01:30:26 +0000 Subject: [PATCH 07/21] fix kernel shape --- .../topi/cuda/conv2d_nhwc_tensorcore_int4.py | 65 ++++--- .../test_topi_conv2d_nhwc_tensorcore_int4.py | 172 ++++++++---------- 2 files changed, 105 insertions(+), 132 deletions(-) diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py index aac9dda71f66..62d07af02a30 100644 --- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py +++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py @@ -134,10 +134,9 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype batch, in_height, in_width, in_channels= get_const_tuple(Input.shape) if in_dtype == 'int4': - kernel_h, kernel_w, _, num_filter, _, _ = get_const_tuple(Filter.shape) + kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape) else: kernel_h, kernel_w, _, num_filter, _, _ = get_const_tuple(Filter.shape) - num_filter = num_filter * wmma_n # if in_dtype == 'int4': # pass # # assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0) @@ -157,12 +156,6 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype out_channels = num_filter out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) - pad_before = [0, pad_top, pad_left, 0] - pad_after = [0, pad_down, pad_right, 0] - PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") - TransPaddedInput = te.compute( - PaddedInput.shape, - lambda n, h, w, c: PaddedInput[n, h, w, c].astype(in_dtype)) # Input feature map: (N, H, W, IC, n, ic) data_shape = (batch // wmma_m, in_height, @@ -170,17 +163,17 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype in_channels // wmma_k, wmma_m, wmma_k) - # Kernel: (H, W, IC, OC, ic, oc) + # Kernel: (H, W, OC, IC, ic, oc) kernel_shape = (kernel_h, kernel_w, - in_channels // wmma_k, out_channels // wmma_n, + in_channels // wmma_k, wmma_n, wmma_k) output_shape = (batch, out_height, out_width, - out_channels) + out_channels) # Reduction axes kh = te.reduce_axis((0, kernel_h), name='kh') kw = te.reduce_axis((0, kernel_w), name='kw') @@ -190,9 +183,9 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype A_transpose = te.compute(data_shape, lambda n, h, w, i, nn, ii: Input[n * wmma_m + nn, h, w, i * wmma_k + ii].astype(in_dtype) ) - # Filter_transpose = te.compute(kernel_shape, - # lambda kh, kw, i, o, oo, ii: Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii] - # ) + Filter_transpose = te.compute(kernel_shape, + lambda kh, kw, o, i, oo, ii: Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii].astype(in_dtype) + ) Apad_transpose = te.compute( (batch // wmma_m, in_height + 2 * padding, in_width + 2 * padding, in_channels // wmma_k, wmma_m, wmma_k), @@ -204,7 +197,7 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype Conv = te.compute((batch // wmma_m, out_height, out_width, out_channels // wmma_n, wmma_m, wmma_n), lambda n, h, w, o, nn, oo: te.sum( Apad_transpose[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("int32") * - Filter[kh, kw, ic, o, oo, ii].astype("int32"), + Filter_transpose[kh, kw, o, ic, oo, ii].astype("int32"), axis=[ic, kh, kw, ii]), name="Conv") Out = te.compute(output_shape, @@ -221,13 +214,13 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): # trans_paddata, kernel = s[Conv].op.input_tensors Apad, kernel = s[Conv].op.input_tensors A_transpose = s[Apad].op.input_tensors[0] - + in_dtype = Apad.dtype batch, _, _, _, _, _ = get_const_tuple(Conv.shape) - if in_dtype == 'int4': - _, _, out_channels, _, _, _ = get_const_tuple(kernel.shape) - else: - _, _, out_channels, _, _, _ = get_const_tuple(kernel.shape) + # if in_dtype == 'int4': + # _, _, _, out_channels, _, _ = get_const_tuple(kernel.shape) + # else: + # _, _, out_channels, _, _, _ = get_const_tuple(kernel.shape) # inline the pad and dtype transform block_x = te.thread_axis('blockIdx.x') @@ -247,11 +240,11 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): # Schedule for autotvm cfg.define_knob("block_row_warps", [1, 2]) cfg.define_knob("block_col_warps", [1, 2]) - cfg.define_knob("warp_row_tiles", [1, 2, 4, 8]) - cfg.define_knob("warp_col_tiles", [1, 2, 4, 8]) + cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16]) + cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16]) cfg.define_knob("chunk", [1, 2, 4, 8]) cfg.define_knob("vector_ws", [1, 8]) - # cfg.define_knob("inline_pad", [1, 2]) + cfg.define_knob("inline_pad", [0, 1]) cfg.define_knob("vector_as", [1, 4, 8, 16]) cfg.define_knob("split_block_k", [1, 2, 4, 8]) @@ -271,26 +264,30 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): vector_ws = cfg["vector_ws"].val vector_as = cfg["vector_as"].val split_block_k = cfg["split_block_k"].val - block_row_warps = 1 - block_col_warps = 1 - warp_row_tiles = 8 - warp_col_tiles = 4 - chunk = 4 - vector_ws = 1 - vector_as = 16 - split_block_k = 8 - - inline_pad = 2 + inline_pad = cfg["inline_pad"].val + # block_row_warps = 1 + # block_col_warps = 2 + # warp_row_tiles = 32 + # warp_col_tiles = 16 + # chunk = 8 + # vector_ws = 1 + # inline_pad = 1 + # vector_as = 1 + # split_block_k = 1 + + # inline_pad = 0 with tvm.target.create('cuda'): schedule_injective_from_existing(s, Out) + schedule_injective_from_existing(s, kernel) # schedule_injective_from_existing(s, Apad) # schedule_injective_from_existing(s, A_transpose) + # s[kernel].compute_inline() # s[Apad].compute_inline() s[A_transpose].compute_inline() # s[Out].compute_inline() - if inline_pad == 1: + if inline_pad: s[Apad].compute_inline() else: with tvm.target.create('cuda'): diff --git a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py index 68aa9f65c9dc..2781a853725b 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py @@ -80,8 +80,8 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, # A = te.placeholder((batch // wmma_m, in_height, in_width, in_channel // wmma_k, wmma_m, wmma_k), name='A', dtype=dtype) if dtype == 'int4' or dtype == 'int8': - # W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype) - W = te.placeholder((kernel, kernel, in_channel // wmma_k, num_filter // wmma_n , wmma_n, wmma_k), name='W', dtype=dtype) + W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype) + # W = te.placeholder((kernel, kernel, in_channel // wmma_k, num_filter // wmma_n , wmma_n, wmma_k), name='W', dtype=dtype) else: W = te.placeholder((kernel, kernel, in_channel // wmma_k, num_filter // wmma_n , wmma_n, wmma_k), name='W', dtype=dtype) @@ -90,9 +90,6 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, a_shape = get_const_tuple(A.shape) w_shape = get_const_tuple(W.shape) bias_shape = get_const_tuple(bias.shape) - # a_shape = (batch, in_height, in_width, in_channel) - w_shape = (kernel, kernel, in_channel, num_filter) - # w_shape = (kernel, kernel, in_channel, num_filter) # dtype = A.dtype @memoize("topi.tests.test_topi_conv2d_nhwc.verify_conv2d_nhwc") @@ -107,7 +104,7 @@ def get_ref_data(): a_np = np.random.randint(low=1, high=7, size=a_shape).astype(np.int32) b_np = np.random.randint(low=1, high=7, size=bias_shape).astype(np.int32) w_np = np.random.randint(low=1, high=7, size=w_shape).astype(np.int32) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)).transpose((0, 1, 3, 2)) elif dtype == 'int8': a_np = np.random.randint(low=1, high=7, size=a_shape).astype(dtype) w_np = np.random.randint(low=1, high=7, size=w_shape).astype(dtype) @@ -141,27 +138,6 @@ def convert_int32_into_int4(a_int32): for a in range(8): a_int4[i,j,k,l] = a_int4[i,j,k,l] | ((a_int32[i,j,k,l * 8 + a] & 0xf) << ((7 - a) * 4)) return a_int4 - def convert_int32_into_int4_shape6(a_int32): - """ convert int32 values into int4 - Parameters - ---------- - a_int32 : int - - Return - ------ - a_int4 : int - """ - M, N, I, J, K, L = a_int32.shape - a_int4 = np.zeros(shape=(M, N, I, J, K, L // 8), dtype=np.int32) - for m in range(M): - for n in range (N): - for i in range(I): - for j in range(J): - for k in range(K): - for l in range(L // 8): - for a in range(8): - a_int4[m,n,i,j,k,l] = a_int4[m,n,i,j,k,l] | ((a_int32[m,n,i,j,k,l * 8 + a] & 0xf) << ((7 - a) * 4)) - return a_int4 a_np, w_np, b_np, c_np = get_ref_data() @@ -172,16 +148,16 @@ def convert_int32_into_int4_shape6(a_int32): # in_width, # in_channel // wmma_k, # wmma_k)).transpose((0,2,3,4,1,5)) - w_np = w_np.reshape((kernel, - kernel, - in_channel // wmma_k, - wmma_k, - num_filter // wmma_n, - wmma_n)).transpose((0,1,2,4,5,3)) + # w_np = w_np.reshape((kernel, + # kernel, + # in_channel // wmma_k, + # wmma_k, + # num_filter // wmma_n, + # wmma_n)).transpose((0,1,2,4,5,3)) if dtype == 'int4': a_np = convert_int32_into_int4(a_np) # b_np = convert_int32_into_int4(b_np) - w_np = convert_int32_into_int4_shape6(w_np) + w_np = convert_int32_into_int4(w_np) def check_device(device): ctx = tvm.context(device, 0) @@ -224,50 +200,50 @@ def check_device(device): print('Time cost of this operator: %f ms' % (evaluator(a, w, c).mean * 1000)) rtol = 1e-3 - # print(c.asnumpy().sum(), c_np.sum()) + print(c.asnumpy().sum(), c_np.sum()) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol) # # #Tuning the performance - # import logging, sys - # logging.getLogger('autotvm').setLevel(logging.DEBUG) - # logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) - - # log_filename = "conv2d_int8_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d.log" % (batch, in_channel, in_size, num_filter, kernel, stride, - # padding, dilation) - # tmp_log_file = log_filename + '.temp' - # num_trial = 1000 - # task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, - # padding, dilation) - # task = autotvm.create('conv2d_nhwc_tensorcore_int4.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) - # print(task.config_space) + import logging, sys + logging.getLogger('autotvm').setLevel(logging.DEBUG) + logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) + + log_filename = "conv2d_int4_nhwc_tensorcore_kernel_shape_%d_%d_%d_%d_%d_%d_%d_%d.log" % (batch, in_channel, in_size, num_filter, kernel, stride, + padding, dilation) + tmp_log_file = log_filename + '.temp' + num_trial = 1000 + task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, + padding, dilation) + task = autotvm.create('conv2d_nhwc_tensorcore_int4.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) + print(task.config_space) - # measure_option = autotvm.measure_option( - # builder='local', - # runner=autotvm.LocalRunner(number=5)) - - # tuner = autotvm.tuner.XGBTuner(task) - # num_trial = min(num_trial, len(task.config_space)) - # with tvm.target.build_config(): - # tuner.tune(n_trial=num_trial, - # measure_option=measure_option, - # callbacks=[autotvm.callback.progress_bar(num_trial, prefix=task_name), - # autotvm.callback.log_to_file(tmp_log_file)]) - - # dispatch_context = autotvm.apply_history_best(tmp_log_file) - # best_config = dispatch_context.query(task.target, task.workload) - # print("\nBest config:") - # print(best_config) - - # #pick the best record to a cache file - # autotvm.record.pick_best(tmp_log_file, log_filename) - # os.remove(tmp_log_file) + measure_option = autotvm.measure_option( + builder='local', + runner=autotvm.LocalRunner(number=5)) + + tuner = autotvm.tuner.XGBTuner(task, feature_type='knob') + num_trial = min(num_trial, len(task.config_space)) + with tvm.target.build_config(): + tuner.tune(n_trial=num_trial, + measure_option=measure_option, + callbacks=[autotvm.callback.progress_bar(num_trial, prefix=task_name), + autotvm.callback.log_to_file(tmp_log_file)]) + + dispatch_context = autotvm.apply_history_best(tmp_log_file) + best_config = dispatch_context.query(task.target, task.workload) + print("\nBest config:") + print(best_config) + + #pick the best record to a cache file + autotvm.record.pick_best(tmp_log_file, log_filename) + os.remove(tmp_log_file) - # with autotvm.apply_graph_best(log_filename): - # with tvm.target.create(device): - # func = tvm.build(s, [A, W, C], device, name="conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, - # padding, dilation)) - # evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) - # print('Time cost of this operator after tuning: %f ms' % (evaluator(a, w, c).mean * 1000)) + with autotvm.apply_graph_best(log_filename): + with tvm.target.create(device): + func = tvm.build(s, [A, W, C], device, name="conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, + padding, dilation)) + evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) + print('Time cost of this operator after tuning: %f ms' % (evaluator(a, w, c).mean * 1000)) check_device(devices) @@ -286,31 +262,31 @@ def test_conv2d_nhwc_tensorcore(): # verify_conv2d_nhwc(64, 256, 14, 512, 1, 2, 0) # verify_conv2d_nhwc(64, 512, 7, 512, 3, 1, 1) - # verify_conv2d_nhwc(32, 64, 56, 64, 3, 1, 1) - # verify_conv2d_nhwc(32, 64, 56, 64, 1, 1, 0) - # verify_conv2d_nhwc(32, 64, 56, 128, 3, 2, 1) - # verify_conv2d_nhwc(32, 64, 56, 64, 1, 2, 0) - # verify_conv2d_nhwc(32, 128, 28, 128, 3, 1, 1) - # verify_conv2d_nhwc(32, 128, 28, 256, 3, 2, 1) - # verify_conv2d_nhwc(32, 128, 28, 256, 1, 2, 0) - # verify_conv2d_nhwc(32, 256, 14, 256, 3, 1, 1) - # verify_conv2d_nhwc(32, 256, 14, 512, 3, 2, 1) - # verify_conv2d_nhwc(32, 256, 14, 512, 1, 2, 0) - # verify_conv2d_nhwc(32, 512, 7, 512, 3, 1, 1) - - # verify_conv2d_nhwc(16, 64, 56, 64, 3, 1, 1) - # verify_conv2d_nhwc(16, 64, 56, 64, 1, 1, 0) - # verify_conv2d_nhwc(16, 64, 56, 128, 3, 2, 1) - # verify_conv2d_nhwc(16, 64, 56, 64, 1, 2, 0) - # verify_conv2d_nhwc(16, 128, 28, 128, 3, 1, 1) - # verify_conv2d_nhwc(16, 128, 28, 256, 3, 2, 1) - # verify_conv2d_nhwc(16, 128, 28, 256, 1, 2, 0) - # verify_conv2d_nhwc(16, 256, 14, 256, 3, 1, 1) - # verify_conv2d_nhwc(16, 256, 14, 512, 3, 2, 1) - # verify_conv2d_nhwc(16, 256, 14, 512, 1, 2, 0) - # verify_conv2d_nhwc(16, 512, 7, 512, 3, 1, 1) - - verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) + verify_conv2d_nhwc(32, 64, 56, 64, 3, 1, 1) + verify_conv2d_nhwc(32, 64, 56, 64, 1, 1, 0) + verify_conv2d_nhwc(32, 64, 56, 128, 3, 2, 1) + verify_conv2d_nhwc(32, 64, 56, 64, 1, 2, 0) + verify_conv2d_nhwc(32, 128, 28, 128, 3, 1, 1) + verify_conv2d_nhwc(32, 128, 28, 256, 3, 2, 1) + verify_conv2d_nhwc(32, 128, 28, 256, 1, 2, 0) + verify_conv2d_nhwc(32, 256, 14, 256, 3, 1, 1) + verify_conv2d_nhwc(32, 256, 14, 512, 3, 2, 1) + verify_conv2d_nhwc(32, 256, 14, 512, 1, 2, 0) + verify_conv2d_nhwc(32, 512, 7, 512, 3, 1, 1) + + verify_conv2d_nhwc(16, 64, 56, 64, 3, 1, 1) + verify_conv2d_nhwc(16, 64, 56, 64, 1, 1, 0) + verify_conv2d_nhwc(16, 64, 56, 128, 3, 2, 1) + verify_conv2d_nhwc(16, 64, 56, 64, 1, 2, 0) + verify_conv2d_nhwc(16, 128, 28, 128, 3, 1, 1) + verify_conv2d_nhwc(16, 128, 28, 256, 3, 2, 1) + verify_conv2d_nhwc(16, 128, 28, 256, 1, 2, 0) + verify_conv2d_nhwc(16, 256, 14, 256, 3, 1, 1) + verify_conv2d_nhwc(16, 256, 14, 512, 3, 2, 1) + verify_conv2d_nhwc(16, 256, 14, 512, 1, 2, 0) + verify_conv2d_nhwc(16, 512, 7, 512, 3, 1, 1) + + # verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) # verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) # verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) # verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) From 112f948618d9b78f2d8c6f20fc5475d2eff025ad Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Thu, 14 May 2020 14:03:21 +0000 Subject: [PATCH 08/21] update intrinsic --- .../topi/cuda/conv2d_nhwc_tensorcore_int4.py | 93 ++++++++---- .../test_topi_conv2d_nhwc_tensorcore_int4.py | 136 +++++++++--------- 2 files changed, 132 insertions(+), 97 deletions(-) diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py index 62d07af02a30..30fbdd696feb 100644 --- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py +++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py @@ -37,8 +37,8 @@ def intrin_wmma_load_matrix(scope, m, n, l, in_dtype): else: A = tvm.te.placeholder((n, l), name='A', dtype=in_dtype) C = tvm.te.compute((n, l), lambda i, j: A[i, j], name='C') - BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) - BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) + BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=8) + BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=8) def intrin_func(ins, outs): ib = tvm.tir.ir_builder.create() @@ -46,11 +46,11 @@ def intrin_func(ins, outs): BC = outs[0] if scope == "wmma.matrix_a": ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', - BC.data, n, m, l, BC.elem_offset // 256, + BC.data, m, n, l, BC.elem_offset // (m * l), BA.access_ptr('r'), l, 'row_major')) elif scope == "wmma.matrix_b": ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', - BC.data, n, m, l, BC.elem_offset // 256, + BC.data, m, n, l, BC.elem_offset // (n * l), BA.access_ptr('r'), l, 'col_major')) return ib.get() @@ -65,9 +65,9 @@ def intrin_wmma_gemm(m, n, l, in_dtype): lambda ii, jj: te.sum(A[ii, k].astype('int32') * B[jj, k].astype('int32'), axis=k), name='C') - BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256) - BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256) - BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=64) + BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=8) + BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=8) + BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=8) def intrin_func(ins, outs): BA, BB = ins @@ -75,16 +75,16 @@ def intrin_func(ins, outs): def init(): ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, m, l, BC.elem_offset // n * n, 0.0)) + ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, m, n, l, BC.elem_offset, 0.0)) return ib.get() def update(): ib = tvm.tir.ir_builder.create() ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync', - BC.data, BC.elem_offset // 64, - BA.data, BA.elem_offset // 256, - BB.data, BB.elem_offset // 256, - BC.data, BC.elem_offset // 64)) + BC.data, BC.elem_offset // (m * n), + BA.data, BA.elem_offset // (m * l), + BB.data, BB.elem_offset // (n * l), + BC.data, BC.elem_offset // (m * n))) return ib.get() return update(), init(), update() @@ -94,16 +94,16 @@ def update(): def intrin_wmma_store_matrix(m, n, l): A = te.placeholder((m, n), name='A', dtype='int32') - BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=64) + BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=8) C = te.compute((m, n), lambda i, j: A[i, j], name='C') - BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=64) + BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=8) def intrin_func(ins, outs): ib = tvm.tir.ir_builder.create() BA = ins[0] BC = outs[0] ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync', - BA.data, n, m, l, BA.elem_offset // 64, + BA.data, m, n, l, BA.elem_offset // (m * n), BC.access_ptr('w'), n, 'row_major')) return ib.get() @@ -128,12 +128,12 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype wmma_n = wmma_m = 8 wmma_k = 32 else: - wmma_m = 16 - wmma_n = 16 + wmma_m = 8 + wmma_n = 32 wmma_k = 16 batch, in_height, in_width, in_channels= get_const_tuple(Input.shape) - if in_dtype == 'int4': + if in_dtype == 'int4' or in_dtype == 'int8': kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape) else: kernel_h, kernel_w, _, num_filter, _, _ = get_const_tuple(Filter.shape) @@ -265,15 +265,15 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): vector_as = cfg["vector_as"].val split_block_k = cfg["split_block_k"].val inline_pad = cfg["inline_pad"].val - # block_row_warps = 1 - # block_col_warps = 2 - # warp_row_tiles = 32 - # warp_col_tiles = 16 - # chunk = 8 - # vector_ws = 1 - # inline_pad = 1 - # vector_as = 1 - # split_block_k = 1 + block_row_warps = 1 + block_col_warps = 1 + warp_row_tiles = 8 + warp_col_tiles = 4 + chunk = 2 + vector_ws = 1 + inline_pad = 0 + vector_as = 16 + split_block_k = 1 # inline_pad = 0 @@ -304,7 +304,7 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): # elif (batch % 32 == 0 and out_channels % 8 == 0): # cfg.define_knob("wmma_m", [32, 16, 8]) # wmma_m = cfg["wmma_m"].val - wmma_m = 16 + wmma_m = 8 wmma_k = 16 if wmma_m == 16: wmma_n = 16 @@ -353,7 +353,7 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): s[AS].vectorize(_t) # Schedule for W's share memory - s[WS].compute_at(s[ConvF], kh) + s[WS].compute_at(s[ConvF], kw) kh, kw, ic, o, ii, oo = WS.op.axis tx, xo = s[WS].split(o, nparts=block_row_warps) ty, yo = s[WS].split(xo, nparts=block_col_warps) @@ -370,6 +370,41 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): s[Conv].tensorize(nnc, intrin_wmma_store_matrix(wmma_m, wmma_n, wmma_k)) s[ConvF].tensorize(nnf, intrin_wmma_gemm(wmma_m, wmma_n, wmma_k, in_dtype)) + # shape = (wmma_m, wmma_n, wmma_k) + + # AS_shape = (wmma_m, wmma_k) + # AL_shape = (wmma_m, wmma_k) + # WS_shape = (wmma_n, wmma_k) + # WL_shape = (wmma_n, wmma_k) + # # else: + # # WS_shape = (wmma_k, wmma_n) + # # WL_shape = (wmma_k, wmma_n) + # CL_shape = (wmma_m, wmma_n) + # CS_shape = (wmma_m, wmma_n) + + # AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype) + # WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype) + # k_gemm = te.reduce_axis((0, wmma_k), name="k") + # CL_compute = te.compute(CL_shape, lambda ii, jj: + # te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) * \ + # WL_gemm[jj, k_gemm].astype(out_dtype), axis=k_gemm), + # name='C') + # AL_strides = [wmma_k, 1] + # AS_strides = [wmma_k, 1] + # WL_strides = [wmma_k, 1] + # WS_strides = [wmma_k, 1] + # CL_strides = [wmma_n, 1] + # CS_strides = [wmma_n, 1] + + # s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape, + # "row_major", AS_shape, AL_shape, in_dtype)) + # s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, + # "col_major", WS_shape, WL_shape, in_dtype)) + # s[Conv].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, + # shape, out_dtype, CL_shape, CS_shape)) + # s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, + # WL_strides, CL_strides, shape)) + N, OH, OW, CO, nn, mm = get_const_tuple(Conv.shape) if in_dtype == 'int4': diff --git a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py index 2781a853725b..8c22842715c4 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py @@ -65,14 +65,14 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) # choose dtype from int4, int8 and float16 - dtype = 'int4' + dtype = 'int8' out_dtype = 'int32' if dtype == 'int4': wmma_n = wmma_m = 8 wmma_k = 32 else: - wmma_m = 16 - wmma_n = 16 + wmma_m = 32 + wmma_n = 8 wmma_k = 16 in_height = in_width = in_size @@ -109,7 +109,7 @@ def get_ref_data(): a_np = np.random.randint(low=1, high=7, size=a_shape).astype(dtype) w_np = np.random.randint(low=1, high=7, size=w_shape).astype(dtype) b_np = np.random.randint(low=1, high=7, size=bias_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)).transpose((0, 1, 3, 2)) c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) if add_bias: @@ -200,50 +200,50 @@ def check_device(device): print('Time cost of this operator: %f ms' % (evaluator(a, w, c).mean * 1000)) rtol = 1e-3 - print(c.asnumpy().sum(), c_np.sum()) + # print(c.asnumpy().sum(), c_np.sum()) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol) # # #Tuning the performance - import logging, sys - logging.getLogger('autotvm').setLevel(logging.DEBUG) - logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) - - log_filename = "conv2d_int4_nhwc_tensorcore_kernel_shape_%d_%d_%d_%d_%d_%d_%d_%d.log" % (batch, in_channel, in_size, num_filter, kernel, stride, - padding, dilation) - tmp_log_file = log_filename + '.temp' - num_trial = 1000 - task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, - padding, dilation) - task = autotvm.create('conv2d_nhwc_tensorcore_int4.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) - print(task.config_space) + # import logging, sys + # logging.getLogger('autotvm').setLevel(logging.DEBUG) + # logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) + + # log_filename = "conv2d_" + dtype +"_nhwc_tensorcore_kernel_shape_%d_%d_%d_%d_%d_%d_%d_%d.log" % (batch, in_channel, in_size, num_filter, kernel, stride, + # padding, dilation) + # tmp_log_file = log_filename + '.temp' + # num_trial = 1000 + # task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, + # padding, dilation) + # task = autotvm.create('conv2d_nhwc_tensorcore_int4.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) + # print(task.config_space) - measure_option = autotvm.measure_option( - builder='local', - runner=autotvm.LocalRunner(number=5)) - - tuner = autotvm.tuner.XGBTuner(task, feature_type='knob') - num_trial = min(num_trial, len(task.config_space)) - with tvm.target.build_config(): - tuner.tune(n_trial=num_trial, - measure_option=measure_option, - callbacks=[autotvm.callback.progress_bar(num_trial, prefix=task_name), - autotvm.callback.log_to_file(tmp_log_file)]) - - dispatch_context = autotvm.apply_history_best(tmp_log_file) - best_config = dispatch_context.query(task.target, task.workload) - print("\nBest config:") - print(best_config) - - #pick the best record to a cache file - autotvm.record.pick_best(tmp_log_file, log_filename) - os.remove(tmp_log_file) + # measure_option = autotvm.measure_option( + # builder='local', + # runner=autotvm.LocalRunner(number=5)) + + # tuner = autotvm.tuner.XGBTuner(task, feature_type='knob') + # num_trial = min(num_trial, len(task.config_space)) + # with tvm.target.build_config(): + # tuner.tune(n_trial=num_trial, + # measure_option=measure_option, + # callbacks=[autotvm.callback.progress_bar(num_trial, prefix=task_name), + # autotvm.callback.log_to_file(tmp_log_file)]) + + # dispatch_context = autotvm.apply_history_best(tmp_log_file) + # best_config = dispatch_context.query(task.target, task.workload) + # print("\nBest config:") + # print(best_config) + + # #pick the best record to a cache file + # autotvm.record.pick_best(tmp_log_file, log_filename) + # os.remove(tmp_log_file) - with autotvm.apply_graph_best(log_filename): - with tvm.target.create(device): - func = tvm.build(s, [A, W, C], device, name="conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, - padding, dilation)) - evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) - print('Time cost of this operator after tuning: %f ms' % (evaluator(a, w, c).mean * 1000)) + # with autotvm.apply_graph_best(log_filename): + # with tvm.target.create(device): + # func = tvm.build(s, [A, W, C], device, name="conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, + # padding, dilation)) + # evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) + # print('Time cost of this operator after tuning: %f ms' % (evaluator(a, w, c).mean * 1000)) check_device(devices) @@ -262,31 +262,31 @@ def test_conv2d_nhwc_tensorcore(): # verify_conv2d_nhwc(64, 256, 14, 512, 1, 2, 0) # verify_conv2d_nhwc(64, 512, 7, 512, 3, 1, 1) - verify_conv2d_nhwc(32, 64, 56, 64, 3, 1, 1) - verify_conv2d_nhwc(32, 64, 56, 64, 1, 1, 0) - verify_conv2d_nhwc(32, 64, 56, 128, 3, 2, 1) - verify_conv2d_nhwc(32, 64, 56, 64, 1, 2, 0) - verify_conv2d_nhwc(32, 128, 28, 128, 3, 1, 1) - verify_conv2d_nhwc(32, 128, 28, 256, 3, 2, 1) - verify_conv2d_nhwc(32, 128, 28, 256, 1, 2, 0) - verify_conv2d_nhwc(32, 256, 14, 256, 3, 1, 1) - verify_conv2d_nhwc(32, 256, 14, 512, 3, 2, 1) - verify_conv2d_nhwc(32, 256, 14, 512, 1, 2, 0) - verify_conv2d_nhwc(32, 512, 7, 512, 3, 1, 1) - - verify_conv2d_nhwc(16, 64, 56, 64, 3, 1, 1) - verify_conv2d_nhwc(16, 64, 56, 64, 1, 1, 0) - verify_conv2d_nhwc(16, 64, 56, 128, 3, 2, 1) - verify_conv2d_nhwc(16, 64, 56, 64, 1, 2, 0) - verify_conv2d_nhwc(16, 128, 28, 128, 3, 1, 1) - verify_conv2d_nhwc(16, 128, 28, 256, 3, 2, 1) - verify_conv2d_nhwc(16, 128, 28, 256, 1, 2, 0) - verify_conv2d_nhwc(16, 256, 14, 256, 3, 1, 1) - verify_conv2d_nhwc(16, 256, 14, 512, 3, 2, 1) - verify_conv2d_nhwc(16, 256, 14, 512, 1, 2, 0) - verify_conv2d_nhwc(16, 512, 7, 512, 3, 1, 1) - - # verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) + # verify_conv2d_nhwc(32, 64, 56, 64, 3, 1, 1) + # verify_conv2d_nhwc(32, 64, 56, 64, 1, 1, 0) + # verify_conv2d_nhwc(32, 64, 56, 128, 3, 2, 1) + # verify_conv2d_nhwc(32, 64, 56, 64, 1, 2, 0) + # verify_conv2d_nhwc(32, 128, 28, 128, 3, 1, 1) + # verify_conv2d_nhwc(32, 128, 28, 256, 3, 2, 1) + # verify_conv2d_nhwc(32, 128, 28, 256, 1, 2, 0) + # verify_conv2d_nhwc(32, 256, 14, 256, 3, 1, 1) + # verify_conv2d_nhwc(32, 256, 14, 512, 3, 2, 1) + # verify_conv2d_nhwc(32, 256, 14, 512, 1, 2, 0) + # verify_conv2d_nhwc(32, 512, 7, 512, 3, 1, 1) + + # verify_conv2d_nhwc(16, 64, 56, 64, 3, 1, 1) + # verify_conv2d_nhwc(16, 64, 56, 64, 1, 1, 0) + # verify_conv2d_nhwc(16, 64, 56, 128, 3, 2, 1) + # verify_conv2d_nhwc(16, 64, 56, 64, 1, 2, 0) + # verify_conv2d_nhwc(16, 128, 28, 128, 3, 1, 1) + # verify_conv2d_nhwc(16, 128, 28, 256, 3, 2, 1) + # verify_conv2d_nhwc(16, 128, 28, 256, 1, 2, 0) + # verify_conv2d_nhwc(16, 256, 14, 256, 3, 1, 1) + # verify_conv2d_nhwc(16, 256, 14, 512, 3, 2, 1) + # verify_conv2d_nhwc(16, 256, 14, 512, 1, 2, 0) + # verify_conv2d_nhwc(16, 512, 7, 512, 3, 1, 1) + + verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) # verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) # verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) # verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) From a8e7d2121243de05d6c92b83d9424ea269c1d29e Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Thu, 14 May 2020 14:19:56 +0000 Subject: [PATCH 09/21] update intrinsic --- .../topi/cuda/conv2d_nhwc_tensorcore_int4.py | 174 +++++------------- topi/python/topi/cuda/tensor_intrin.py | 27 +-- 2 files changed, 55 insertions(+), 146 deletions(-) diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py index 30fbdd696feb..c772de54b703 100644 --- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py +++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py @@ -25,89 +25,7 @@ from ..nn.pad import pad from ..nn.util import get_pad_tuple from topi.cuda.injective import schedule_injective_from_existing -# from .tensor_intrin import intrin_wmma_load_matrix_A -# from .tensor_intrin import intrin_wmma_load_matrix_W -# from .tensor_intrin import intrin_wmma_store_matrix -# from .tensor_intrin import intrin_wmma_gemm - -def intrin_wmma_load_matrix(scope, m, n, l, in_dtype): - if scope == 'wmma.matrix_a': - A = tvm.te.placeholder((m, l), name='A', dtype=in_dtype) - C = tvm.te.compute((m, l), lambda i, j: A[i, j], name='C') - else: - A = tvm.te.placeholder((n, l), name='A', dtype=in_dtype) - C = tvm.te.compute((n, l), lambda i, j: A[i, j], name='C') - BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=8) - BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=8) - def intrin_func(ins, outs): - ib = tvm.tir.ir_builder.create() - - BA = ins[0] - BC = outs[0] - if scope == "wmma.matrix_a": - ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', - BC.data, m, n, l, BC.elem_offset // (m * l), - BA.access_ptr('r'), l, 'row_major')) - elif scope == "wmma.matrix_b": - ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', - BC.data, m, n, l, BC.elem_offset // (n * l), - BA.access_ptr('r'), l, 'col_major')) - return ib.get() - - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) - - -def intrin_wmma_gemm(m, n, l, in_dtype): - A = te.placeholder((m, l), name='A', dtype=in_dtype) - B = te.placeholder((n, l), name='B', dtype=in_dtype) - k = te.reduce_axis((0, l), name="k") - C = te.compute((m, n), - lambda ii, jj: - te.sum(A[ii, k].astype('int32') * B[jj, k].astype('int32'), axis=k), - name='C') - BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=8) - BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=8) - BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=8) - - def intrin_func(ins, outs): - BA, BB = ins - BC, = outs - - def init(): - ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, m, n, l, BC.elem_offset, 0.0)) - return ib.get() - - def update(): - ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync', - BC.data, BC.elem_offset // (m * n), - BA.data, BA.elem_offset // (m * l), - BB.data, BB.elem_offset // (n * l), - BC.data, BC.elem_offset // (m * n))) - return ib.get() - - return update(), init(), update() - - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) - - -def intrin_wmma_store_matrix(m, n, l): - A = te.placeholder((m, n), name='A', dtype='int32') - BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=8) - C = te.compute((m, n), lambda i, j: A[i, j], name='C') - BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=8) - - def intrin_func(ins, outs): - ib = tvm.tir.ir_builder.create() - BA = ins[0] - BC = outs[0] - ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync', - BA.data, m, n, l, BA.elem_offset // (m * n), - BC.access_ptr('w'), n, 'row_major')) - return ib.get() - - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) +from .tensor_intrin import intrin_wmma_load_matrix_A, intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype, out_dtype): """Compute declaration for tensorcore""" @@ -137,16 +55,15 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape) else: kernel_h, kernel_w, _, num_filter, _, _ = get_const_tuple(Filter.shape) - # if in_dtype == 'int4': - # pass - # # assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0) - # else: - # assert (batch % 16 == 0 and in_channels % 16 == 0 and num_filter % 16 == 0) or \ - # (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0) or \ - # (batch % 32 == 0 and in_channels % 16 == 0 and num_filter % 8 == 0), \ - # "The shape of (batch, in_channels, num_filter) "\ - # "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for fp16 and int8, "\ - # "and (8, 32, 8) for int4" + if in_dtype == 'int4': + assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0) + else: + assert (batch % 16 == 0 and in_channels % 16 == 0 and num_filter % 16 == 0) or \ + (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0) or \ + (batch % 32 == 0 and in_channels % 16 == 0 and num_filter % 8 == 0), \ + "The shape of (batch, in_channels, num_filter) "\ + "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for fp16 and int8, "\ + "and (8, 32, 8) for int4" # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 @@ -211,7 +128,6 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): Conv = s[Out].op.input_tensors[0] ic, kh, kw, ii = s[Conv].op.reduce_axis out_dtype = Conv.dtype - # trans_paddata, kernel = s[Conv].op.input_tensors Apad, kernel = s[Conv].op.input_tensors A_transpose = s[Apad].op.input_tensors[0] @@ -365,45 +281,37 @@ def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): s[WS].bind(to, thread_x) s[WS].vectorize(ti) - s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a', wmma_m, wmma_n, wmma_k, in_dtype)) - s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b', wmma_m, wmma_n, wmma_k, in_dtype)) - s[Conv].tensorize(nnc, intrin_wmma_store_matrix(wmma_m, wmma_n, wmma_k)) - s[ConvF].tensorize(nnf, intrin_wmma_gemm(wmma_m, wmma_n, wmma_k, in_dtype)) - - # shape = (wmma_m, wmma_n, wmma_k) - - # AS_shape = (wmma_m, wmma_k) - # AL_shape = (wmma_m, wmma_k) - # WS_shape = (wmma_n, wmma_k) - # WL_shape = (wmma_n, wmma_k) - # # else: - # # WS_shape = (wmma_k, wmma_n) - # # WL_shape = (wmma_k, wmma_n) - # CL_shape = (wmma_m, wmma_n) - # CS_shape = (wmma_m, wmma_n) - - # AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype) - # WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype) - # k_gemm = te.reduce_axis((0, wmma_k), name="k") - # CL_compute = te.compute(CL_shape, lambda ii, jj: - # te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) * \ - # WL_gemm[jj, k_gemm].astype(out_dtype), axis=k_gemm), - # name='C') - # AL_strides = [wmma_k, 1] - # AS_strides = [wmma_k, 1] - # WL_strides = [wmma_k, 1] - # WS_strides = [wmma_k, 1] - # CL_strides = [wmma_n, 1] - # CS_strides = [wmma_n, 1] - - # s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape, - # "row_major", AS_shape, AL_shape, in_dtype)) - # s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, - # "col_major", WS_shape, WL_shape, in_dtype)) - # s[Conv].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, - # shape, out_dtype, CL_shape, CS_shape)) - # s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, - # WL_strides, CL_strides, shape)) + shape = (wmma_m, wmma_n, wmma_k) + + AS_shape = (wmma_m, wmma_k) + AL_shape = (wmma_m, wmma_k) + WS_shape = (wmma_n, wmma_k) + WL_shape = (wmma_n, wmma_k) + CL_shape = (wmma_m, wmma_n) + CS_shape = (wmma_m, wmma_n) + + AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype) + WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype) + k_gemm = te.reduce_axis((0, wmma_k), name="k") + CL_compute = te.compute(CL_shape, lambda ii, jj: + te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) * \ + WL_gemm[jj, k_gemm].astype(out_dtype), axis=k_gemm), + name='C') + AL_strides = [wmma_k, 1] + AS_strides = [wmma_k, 1] + WL_strides = [wmma_k, 1] + WS_strides = [wmma_k, 1] + CL_strides = [wmma_n, 1] + CS_strides = [wmma_n, 1] + + s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape, + "row_major", AS_shape, AL_shape, in_dtype)) + s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, + "col_major", WS_shape, WL_shape, in_dtype)) + s[Conv].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, + shape, out_dtype, CL_shape, CS_shape)) + s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, + WL_strides, CL_strides, shape)) N, OH, OW, CO, nn, mm = get_const_tuple(Conv.shape) diff --git a/topi/python/topi/cuda/tensor_intrin.py b/topi/python/topi/cuda/tensor_intrin.py index fcb1458984e9..c0540aa42c20 100644 --- a/topi/python/topi/cuda/tensor_intrin.py +++ b/topi/python/topi/cuda/tensor_intrin.py @@ -85,11 +85,11 @@ def intrin_wmma_load_matrix_A(strides_dst, strides_from, shape, layout, A_shape, A = te.placeholder(A_shape, name='A', dtype=in_dtype) BA = tvm.tir.decl_buffer(A.shape, A.dtype, - scope='shared', strides=[te.var("s1"), te.var("s2"), te.var("s3"), te.var("s4")], + scope='shared', strides=strides_from, data_alignment=32, offset_factor=8) C = te.compute(C_shape, lambda *i: A(*i), name='C') BC = tvm.tir.decl_buffer(C.shape, C.dtype, - scope="wmma.matrix_a", strides=[te.var("s1"), te.var("s2"), te.var("s3"), te.var("s4")], + scope="wmma.matrix_a", strides=strides_dst, data_alignment=32, offset_factor=8) def intrin_func(ins, outs): @@ -110,13 +110,14 @@ def intrin_func(ins, outs): def intrin_wmma_load_matrix_W(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype): """Intrin function for loading data from shared memory to wmma.matrix_b""" wmma_m, wmma_n, wmma_k = shape + A = te.placeholder(A_shape, name='A', dtype=in_dtype) BA = tvm.tir.decl_buffer(A.shape, A.dtype, - scope='shared', strides=[te.var("s1"), te.var("s2")], + scope='shared', strides=strides_from, data_alignment=32, offset_factor=8) C = te.compute(C_shape, lambda *i: A(*i), name='C') BC = tvm.tir.decl_buffer(C.shape, C.dtype, - scope="wmma.matrix_b", strides=[te.var("s3"), te.var("s4")], + scope="wmma.matrix_b", strides=strides_dst, data_alignment=32, offset_factor=8) def intrin_func(ins, outs): @@ -144,7 +145,7 @@ def intrin_wmma_store_matrix(strides_dst, strides_from, shape, out_dtype, A_shap offset_factor=8) C = te.compute(C_shape, lambda *i: A(*i), name='C') BC = tvm.tir.decl_buffer(C.shape, C.dtype, - scope='shared', strides=strides_dst, + scope='global', strides=strides_dst, data_alignment=32, offset_factor=8) def intrin_func(ins, outs): @@ -182,25 +183,25 @@ def intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, strides_A, BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, - offset_factor=8, strides=[te.var("s1"), te.var("s2"), te.var("s3"), te.var("s4")]) + offset_factor=8, strides=strides_A) BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, - offset_factor=8, strides=[te.var("s1"), te.var("s2")]) + offset_factor=8, strides=strides_W) BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, - offset_factor=8, strides=[te.var("s1"), te.var("s2"), te.var("s3"), te.var("s4")]) + offset_factor=8, strides=strides_Conv) def intrin_func(ins, outs): BA, BB = ins BC, = outs - def warp_index(offset, row, col): + def warp_idnex(offset, row, col): row = row * col return offset // row + offset % row // col - warp_index_A = warp_index(BA.elem_offset, wmma_m, wmma_k) - warp_index_B = warp_index(BB.elem_offset, wmma_k, wmma_n) - warp_index_C = warp_index(BC.elem_offset, wmma_m, wmma_n) + warp_index_A = warp_idnex(BA.elem_offset, wmma_m, wmma_k) + warp_index_B = warp_idnex(BB.elem_offset, wmma_k, wmma_n) + warp_index_C = warp_idnex(BC.elem_offset, wmma_m, wmma_n) def init(): ib = tvm.tir.ir_builder.create() @@ -220,4 +221,4 @@ def update(): return update(), init(), update() - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) \ No newline at end of file From 37c690c6b6217031337d1e92333accfcc9ead185 Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Wed, 22 Jul 2020 14:41:53 +0000 Subject: [PATCH 10/21] support int4/int8 hwnc layout --- python/tvm/relay/op/strategy/cuda.py | 15 + topi/python/topi/cuda/__init__.py | 1 + topi/python/topi/cuda/conv2d_alter_op.py | 30 ++ .../topi/cuda/conv2d_hwnc_tensorcore.py | 405 ++++++++++++++++++ .../test_topi_conv2d_hwnc_tensorcore.py | 136 ++++++ 5 files changed, 587 insertions(+) create mode 100644 topi/python/topi/cuda/conv2d_hwnc_tensorcore.py create mode 100644 topi/tests/python/test_topi_conv2d_hwnc_tensorcore.py diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 9189b5edff83..5c31d0532375 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -148,6 +148,21 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore), name="conv2d_nhwc_tensorcore.cuda", plevel=20) + elif layout == "HWNC": + assert kernel_layout in ["HWOI", "HWOI16o16i", "HWOI8o32i", "HWOI32o16i"] + _, _, N, in_channels = get_const_tuple(data.shape) + pre_computed = len(kernel.shape) == 6 + if pre_computed: + _, _, oc_chunk, _, oc_block_factor, _ = get_const_tuple(kernel.shape) + out_channels = oc_chunk * oc_block_factor + else: + _, _, out_channels, _ = get_const_tuple(kernel.shape) + if topi.cuda.is_shape_tensorcore_direct_qualified(batch=N, in_channels=in_channels, num_filter=out_channels, in_dtype=data.dtype): + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.conv2d_hwnc_tensorcore), + wrap_topi_schedule(topi.cuda.schedule_conv2d_hwnc_tensorcore), + name="conv2d_hwnc_tensorcore_direct.cuda", + plevel=20) elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: assert kernel_layout == "OIHW4o4i" strategy.add_implementation( diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index ec2165dac6dd..f75e72cff7a4 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -49,3 +49,4 @@ from .conv2d_nhwc_tensorcore_int4 import * from .conv3d_ndhwc_tensorcore import * from .dense_tensorcore import * +from .conv2d_hwnc_tensorcore import * \ No newline at end of file diff --git a/topi/python/topi/cuda/conv2d_alter_op.py b/topi/python/topi/cuda/conv2d_alter_op.py index 8d9e86c192a0..16b348a30c8c 100644 --- a/topi/python/topi/cuda/conv2d_alter_op.py +++ b/topi/python/topi/cuda/conv2d_alter_op.py @@ -134,6 +134,36 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): "group_conv2d_NCHWc_int8.cuda") dispatch_ctx.update(target, new_workload, cfg) return relay.nn.conv2d(*inputs, **new_attrs) + + if topi_tmpl == "conv2d_HWNCnc_tensorcore.cuda": + assert data_layout == "HWNC" and kernel_layout == "HWOI" + + H, W, N, CI = get_const_tuple(data.shape) + KH, KW, CO, _ = get_const_tuple(kernel.shape) + + if kernel.dtype in ['int4', 'uint4'] and (CI % 32 != 0 or CO % 8 != 0) or \ + kernel.dtype in ['int8', 'uint8'] and (CI % 16 != 0 or CO % 32 != 0): + return relay.nn.conv2d(*inputs, **new_attrs) + + new_attrs["channels"] = CO + if kernel.dtype in ['int4', 'uint4']: + new_attrs['kernel_layout'] = 'HWOI8o32i' + ic_block_factor = 32 + oc_block_factor = 8 + else: + new_attrs['kernel_layout']= 'HWOI32o16i' + ic_block_factor = 16 + oc_block_factor = 32 + + new_kernel = te.placeholder((KH, KW, CO // oc_block_factor, CI // ic_block_factor, + oc_block_factor, ic_block_factor), dtype=kernel.dtype) + + new_workload = autotvm.task.args_to_workload( + [data, new_kernel, strides, padding, dilation, out_dtype], + "conv2d_HWNCnc_tensorcore.cuda") + + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.conv2d(*inputs, **new_attrs) return None diff --git a/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py b/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py new file mode 100644 index 000000000000..621ed343413f --- /dev/null +++ b/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py @@ -0,0 +1,405 @@ +import numpy as np +import tvm +from tvm import te +from tvm import autotvm +from ..util import get_const_tuple, traverse_inline, simplify, tag +from ..nn.pad import pad +from ..nn.util import get_pad_tuple +from topi.cuda.injective import schedule_injective_from_existing +from .tensor_intrin import intrin_wmma_load_matrix_A, intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm + +def unpack_HWNCnc_to_hwnc(packed_out, out_dtype): + """Unpack conv2d_hwnc output from layout hwncnc to hwnc + + Parameters + ----------- + packed_out : tvm.te.Tensor + The output tensor of conv2d_hwnc. + + out_dtype : str + The output dtype. + + Returns + ------- + unpacked_out : tvm.te.Tensor + The unpacked output tensor in hwnc layout. + """ + H, W, N, O, wmma_m, wmma_n = get_const_tuple(packed_out.shape) + + idxmod = tvm.tir.indexmod + idxdiv = tvm.tir.indexdiv + + oshape = (H, W, N * wmma_m, O * wmma_n) + unpacked_out = \ + te.compute(oshape, + lambda h, w, n, o: + packed_out[h, w, idxdiv(n, wmma_m), idxdiv(o, wmma_n), idxmod(n, wmma_m), idxmod(o, wmma_n)] + .astype(out_dtype), + name='output_unpack', + tag=tag.INJECTIVE+",unpack_hwncc") + return unpacked_out + +def conv2d_hwnc_tensorcore(data, kernel, strides, padding, dilation, in_dtype, out_dtype='int32'): + """Compute conv2d internally using conv2d_nchwc layout for int8 dtype""" + assert data.dtype in ('int4', 'uint4', 'int8', 'uint8') + assert kernel.dtype in ('int4', 'uint4', 'int8', 'uint8') + # assert data.dtype == kernel.dtype + packed_out = hwnc_tensorcore_cuda(data, kernel, strides, padding, dilation, out_dtype) + return unpack_HWNCnc_to_hwnc(packed_out, out_dtype) + +@autotvm.register_topi_compute("conv2d_HWNCnc_tensorcore.cuda") +def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype='int32'): + """Compute declaration for tensorcore""" + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + in_dtype = Input.dtype + + if in_dtype in ['int4', 'uint4']: + wmma_n = wmma_m = 8 + wmma_k = 32 + else: + wmma_m = 8 + wmma_n = 32 + wmma_k = 16 + + pre_computed = len(Filter.shape) == 6 + in_height, in_width, batch, in_channels = get_const_tuple(Input.shape) + if pre_computed: + kernel_h, kernel_w, oc_chunk, ic_chunk, oc_block_factor, ic_block_factor = get_const_tuple(Filter.shape) + num_filter = oc_block_factor * oc_chunk + else: + kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape) + + if in_dtype in ['int4', 'uint4']: + assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0) + else: + assert (batch % 16 == 0 and in_channels % 16 == 0 and num_filter % 16 == 0) or \ + (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0) or \ + (batch % 32 == 0 and in_channels % 16 == 0 and num_filter % 8 == 0), \ + "The shape of (batch, in_channels, num_filter) "\ + "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for fp16 and int8, "\ + "and (8, 32, 8) for int4" + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + + out_channels = num_filter + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + + cfg.add_flop(2 * batch * out_height * out_width * out_channels * in_channels * kernel_h * kernel_w) + + # Input feature map: (H, W, N, IC, n, ic) + data_shape = (in_height, + in_width, + batch // wmma_m, + in_channels // wmma_k, + wmma_m, + wmma_k) + + # Kernel: (H, W, OC, IC, ic, oc) + kernel_shape = (kernel_h, + kernel_w, + out_channels // wmma_n, + in_channels // wmma_k, + wmma_n, + wmma_k) + + # Reduction axes + kh = te.reduce_axis((0, kernel_h), name='kh') + kw = te.reduce_axis((0, kernel_w), name='kw') + ic = te.reduce_axis((0, in_channels // wmma_k), name='ic') + ii = te.reduce_axis((0, wmma_k), name='ii') + + if pre_computed: + packed_kernel = Filter + else: + packed_kernel = te.compute(kernel_shape, + lambda kh, kw, o, i, oo, ii: Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii], + name="packed_kernel" + ) + + packed_data = te.compute(data_shape, + lambda h, w, n, i, nn, ii: Input[h, w, n * wmma_m + nn, i * wmma_k + ii] + ) + + pad_before = [pad_top, pad_left, 0, 0, 0, 0] + pad_after = [pad_down, pad_right, 0, 0, 0, 0] + pad_data = pad(packed_data, pad_before, pad_after, name="pad_data") + + + Conv = te.compute((out_height, out_width, batch // wmma_m, out_channels // wmma_n,wmma_m, wmma_n), + lambda h, w, n, o, nn, oo: te.sum( + (pad_data[h * stride_h + kh, w * stride_w + kw, n, ic, nn, ii].astype('int32') * + packed_kernel[kh, kw, o, ic, oo, ii].astype('int32')), + axis=[ic, kh, kw, ii]), + name="Conv", tag="conv2d_HWNCnc_tensorcore") + return Conv + + +def schedule_hwnc_tensorcore_cuda(cfg, s, Conv): + packed_data, packed_kernel = s[Conv].op.input_tensors + ic, kh, kw, ii = s[Conv].op.reduce_axis + pad_data = s[packed_data].op.input_tensors[0] + + block_x = te.thread_axis('blockIdx.x') + block_y = te.thread_axis('blockIdx.y') + block_z = te.thread_axis('blockIdx.z') + thread_x = te.thread_axis('threadIdx.x') + thread_y = te.thread_axis('threadIdx.y') + thread_z = te.thread_axis('threadIdx.z') + + # Designate the memory hierarchy + AS = s.cache_read(packed_data, 'shared', [Conv]) + WS = s.cache_read(packed_kernel, 'shared', [Conv]) + AF = s.cache_read(AS, 'wmma.matrix_a', [Conv]) + WF = s.cache_read(WS, 'wmma.matrix_b', [Conv]) + ConvF = s.cache_write(Conv, 'wmma.accumulator') + + if Conv.op in s.outputs: + output = Conv + ConvS = s.cache_read(ConvF, 'shared', [Conv]) + OL = ConvS + else: + output = s.outputs[0].output(0) + s[Conv].set_scope('shared') + OL = Conv + + out_dtype = Conv.dtype + + if isinstance(packed_kernel.op, te.tensor.ComputeOp) and packed_kernel.name == "packed_kernel": + if autotvm.GLOBAL_SCOPE.in_tuning: + s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region") + else: + with tvm.target.create('cuda'): + schedule_injective_from_existing(s, packed_kernel) + + if isinstance(pad_data.op, te.tensor.ComputeOp) and "pad" in pad_data.op.tag: + s[pad_data].compute_inline() + data = pad_data.op.input_tensors[0] + + if autotvm.GLOBAL_SCOPE.in_tuning: + # skip this part during tuning to make recrods accurate + # this part will be pre-computed during NNVM's pre-compute optimization pass + s[pad_data].pragma(s[pad_data].op.axis[0], "debug_skip_region") + else: + data = pad_data + s[data].compute_inline() + + data_dtype = data.dtype + kernel_dtype = packed_kernel.dtype + + # Schedule for autotvm + cfg.define_knob("block_row_warps", [1, 2, 4]) + cfg.define_knob("block_col_warps", [1, 2, 4]) + cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16]) + cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16]) + cfg.define_knob("chunk", [1, 2, 4, 8]) + cfg.define_knob("fuse_pack", [0, 1]) + cfg.define_knob("split_block_k_nums", [1, 2, 4, 8, 16, 32]) + cfg.define_knob("vector_ws", [1, 8]) + cfg.define_knob("vector_as", [1, 8, 16]) + + block_row_warps = cfg["block_row_warps"].val + block_col_warps = cfg["block_col_warps"].val + warp_row_tiles = cfg["warp_row_tiles"].val + warp_col_tiles = cfg["warp_col_tiles"].val + chunk = cfg["chunk"].val + vector_as = cfg["vector_as"].val + vector_ws = cfg["vector_ws"].val + split_block_k_nums = cfg["split_block_k_nums"].val + fuse_pack = cfg["fuse_pack"].val + + if not fuse_pack: + s[packed_data].compute_inline() + else: + with tvm.target.create('cuda'): + schedule_injective_from_existing(s, packed_data) + + if data_dtype in ['int4', 'uint4']: + wmma_m = wmma_n = 8 + wmma_k = 32 + else: + wmma_m = 8 + wmma_n = 32 + wmma_k = 16 + + warp_size = 32 + + # Schedule for output + if len(s[output].op.axis) == 4: + hc, wc, nc, oc, = output.op.axis + nc, nnc = s[output].split(nc, factor=wmma_m) + oc, ooc = s[output].split(oc, factor=wmma_n) + else: + hc, wc, nc, oc, nnc, ooc = output.op.axis + + kernel_scope, hc = s[output].split(hc, nparts=1) + + block_k = s[output].fuse(hc, wc) + block_k, split_block_k = s[output].split(block_k, factor=split_block_k_nums) + nc, nci = s[output].split(nc, factor=warp_row_tiles) + block_i, nc = s[output].split(nc, factor=block_row_warps) + oc, oci = s[output].split(oc, factor=warp_col_tiles) + block_j, oc = s[output].split(oc, factor=block_col_warps) + s[output].reorder(block_k, split_block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc) + t = s[output].fuse(nnc, ooc) + _, tx = s[output].split(t, factor=warp_size) + s[output].bind(block_k, block_z) + s[output].bind(block_i, block_x) + s[output].bind(block_j, block_y) + s[output].bind(tx, thread_x) + s[output].bind(nc, thread_y) + s[output].bind(oc, thread_z) + + # Schedule wmma store + s[OL].compute_at(s[output], block_j) + hc, wc, nc, oc, nnc, ooc = OL.op.axis + oc, oci = s[OL].split(oc, factor=warp_col_tiles) + _, oc = s[OL].split(oc, factor=block_col_warps) + nc, nci = s[OL].split(nc, factor=warp_row_tiles) + _, nc = s[OL].split(nc, factor=block_row_warps) + s[OL].reorder(nc, oc, nci, oci, nnc, ooc) + s[OL].bind(nc, thread_y) + s[OL].bind(oc, thread_z) + + # Schedule local computation + s[ConvF].compute_at(s[OL], oc) + h, w, n, o, nnf, oof = ConvF.op.axis + ko, ki = s[ConvF].split(ic, factor=chunk) + s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii) + + cfg.define_reorder("reorder_inner", [ko, kh], policy="all") + cfg["reorder_inner"].apply(s, ConvF, [ko, kh]) + cfg["reorder_inner"].apply(s, ConvF, [ki, kw]) + + cfg.define_knob("compute_at_AS", [0, 1, 2, 3]) + cfg.define_knob("compute_at_WS", [0, 1, 2, 3]) + compute_at_AS = cfg["compute_at_AS"].val + compute_at_WS = cfg["compute_at_WS"].val + + # Move intermediate computation into each output compute tile + s[AF].compute_at(s[ConvF], kw) + s[WF].compute_at(s[ConvF], kw) + + # Schedule for A's share memory + if compute_at_AS == 0: + s[AS].compute_at(s[ConvF], ki) + elif compute_at_AS == 1: + s[AS].compute_at(s[ConvF], kw) + elif compute_at_AS == 2: + s[AS].compute_at(s[ConvF], ko) + else: + s[AS].compute_at(s[ConvF], kh) + # s[AS].compute_at(s[ConvF], kh) + h, w, n, i, nn, ii = AS.op.axis + tx, xo = s[AS].split(n, nparts=block_row_warps) + ty, yo = s[AS].split(xo, nparts=block_col_warps) + t = s[AS].fuse(nn, ii) + to, ti = s[AS].split(t, nparts=warp_size) + ti, _t = s[AS].split(ti, factor=vector_as) + s[AS].bind(tx, thread_y) + s[AS].bind(ty, thread_z) + s[AS].bind(to, thread_x) + s[AS].vectorize(_t) + + # Schedule for W's share memory + if compute_at_WS == 0: + s[WS].compute_at(s[ConvF], ki) + elif compute_at_WS == 1: + s[WS].compute_at(s[ConvF], kw) + elif compute_at_WS == 2: + s[WS].compute_at(s[ConvF], ko) + else: + s[WS].compute_at(s[ConvF], kh) + s[WS].compute_at(s[ConvF], kw) + kh, kw, ic, o, ii, oo = WS.op.axis + tx, xo = s[WS].split(o, nparts=block_row_warps) + ty, yo = s[WS].split(xo, nparts=block_col_warps) + t = s[WS].fuse(ii, oo) + to, ti = s[WS].split(t, nparts=warp_size) + ti, _t = s[WS].split(ti, factor=vector_ws) + s[WS].bind(tx, thread_y) + s[WS].bind(ty, thread_z) + s[WS].bind(to, thread_x) + s[WS].vectorize(ti) + + # double buffer + cfg.define_knob('AS_double_buffer', [0, 1]) + cfg.define_knob('WS_double_buffer', [0, 1]) + if cfg['AS_double_buffer'].val: + s[AS].double_buffer() + if cfg['WS_double_buffer'].val: + s[WS].double_buffer() + + # unroll + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + s[output].pragma(kernel_scope, 'auto_unroll_max_step', + cfg['auto_unroll_max_step'].val) + s[output].pragma(kernel_scope, 'unroll_explicit', False) + + shape = (wmma_m, wmma_n, wmma_k) + + AS_shape = (wmma_m, wmma_k) + AL_shape = (wmma_m, wmma_k) + WS_shape = (wmma_n, wmma_k) + WL_shape = (wmma_n, wmma_k) + CL_shape = (wmma_m, wmma_n) + CS_shape = (wmma_m, wmma_n) + + AL_gemm = te.placeholder(AL_shape, name='A', dtype=data_dtype) + WL_gemm = te.placeholder(WL_shape, name='B', dtype=kernel_dtype) + k_gemm = te.reduce_axis((0, wmma_k), name="k") + CL_compute = te.compute(CL_shape, lambda ii, jj: + te.sum((AL_gemm[ii, k_gemm].astype('int32')* WL_gemm[jj, k_gemm].astype('int32')), axis=k_gemm), + name='C') + + AL_strides = [wmma_k, 1] + AS_strides = [wmma_k, 1] + WL_strides = [wmma_k, 1] + WS_strides = [wmma_k, 1] + CL_strides = [wmma_n, 1] + CS_strides = [wmma_n, 1] + + + s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape, + "row_major", AS_shape, AL_shape, data_dtype)) + + s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, + "col_major", WS_shape, WL_shape, kernel_dtype)) + + s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, + shape, out_dtype, CL_shape, CS_shape)) + + + s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, + WL_strides, CL_strides, shape)) + + return s + +@autotvm.register_topi_schedule("conv2d_HWNCnc_tensorcore.cuda") +def schedule_conv2d_hwnc_tensorcore(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + def _callback(op): + if 'conv2d_HWNCnc_tensorcore' in op.tag: + schedule_hwnc_tensorcore_cuda(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s + diff --git a/topi/tests/python/test_topi_conv2d_hwnc_tensorcore.py b/topi/tests/python/test_topi_conv2d_hwnc_tensorcore.py new file mode 100644 index 000000000000..659138a40b20 --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_hwnc_tensorcore.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-arguments +"""Example code to do convolution.""" + +import numpy as np +import tvm +import os +import topi +import topi.testing +from tvm import te, autotvm +from tvm.contrib.pickle_memoize import memoize +from tvm.contrib import nvcc +from topi.nn.util import get_pad_tuple +from topi.util import get_const_tuple + +_conv2d_hwnc_tensorcore_implement = { + "cuda": (topi.cuda.conv2d_hwnc_tensorcore, topi.cuda.schedule_conv2d_hwnc_tensorcore) +} + +def verify_conv2d_hwnc(batch, in_channel, in_size, num_filter, kernel, stride, + padding, dilation=1, devices='cuda', dtype='int4'): + """Test the conv2d with tensorcore for hwnc layout""" + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + # choose dtype from int4, int8 + assert dtype in ['int4', 'int8'] + out_dtype = 'int32' + + in_height = in_width = in_size + + A = te.placeholder((in_height, in_width, batch, in_channel), name='A', dtype=dtype) + W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype) + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + @memoize("topi.tests.test_topi_conv2d_hwnc.verify_conv2d_hwnc") + def get_ref_data(): + if dtype == 'int4': + a_np = np.random.randint(low=-8, high=7, size=a_shape).transpose((2, 0, 1, 3)) + w_np = np.random.randint(low=-8, high=7, size=w_shape) + dw_np = topi.testing.dilate_python(w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation)) + elif dtype == 'int8': + a_np = np.random.randint(low=-128, high=127, size=a_shape).transpose((2, 0, 1, 3)).astype(dtype) + w_np = np.random.randint(low=-128, high=127, size=w_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + + c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + return a_np, w_np, c_np + + def convert_int32_into_int4(a_int32): + """ convert int32 values into int4 + Parameters + ---------- + a_int32 : int + + Return + ------ + a_int4 : int + """ + I, J, K, L = a_int32.shape + a_int4 = np.zeros(shape=(I, J, K, L // 8), dtype=np.int32) + for i in range(I): + for j in range(J): + for k in range(K): + for l in range(L // 8): + for m in range(min(8, L-l*8)): + a_int4[i, j, k, l] = a_int4[i, j, k, l] | ((a_int32[i, j, k, l * 8 + m] & 0xf) << ((7 - m) * 4)) + return a_int4 + + a_np, w_np, c_np = get_ref_data() + if dtype == 'int4': + a_np = convert_int32_into_int4(a_np) + w_np = convert_int32_into_int4(w_np) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + if not nvcc.have_tensorcore(ctx.compute_version): + print("skip because gpu does not support Tensor Cores") + return + print("Running on target: %s" % device) + with tvm.target.create(device): + fcompute, fschedule = topi.testing.dispatch(device, _conv2d_hwnc_tensorcore_implement) + C = fcompute(A, W, stride, padding, dilation, dtype, 'int32') + s = fschedule([C]) + + a = tvm.nd.array(a_np.transpose((1, 2, 0, 3)), ctx) + w = tvm.nd.array(w_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func(a, w, c) + + rtol = 1e-3 + tvm.testing.assert_allclose(c.asnumpy().transpose((2, 0, 1, 3)), c_np, rtol=rtol) + + check_device(devices) + + +def test_conv2d_hwnc_tensorcore(): + """Test the conv2d with tensorcore for hwnc layout""" + verify_conv2d_hwnc(8, 64, 56, 64, 3, 1, 1) + verify_conv2d_hwnc(8, 32, 7, 8, 3, 1, 1) + verify_conv2d_hwnc(8, 64, 56, 64, 1, 1, 0) + verify_conv2d_hwnc(8, 64, 56, 128, 3, 2, 1) + verify_conv2d_hwnc(8, 64, 56, 64, 1, 2, 0) + verify_conv2d_hwnc(8, 128, 28, 128, 3, 1, 1) + verify_conv2d_hwnc(8, 128, 28, 256, 3, 2, 1) + verify_conv2d_hwnc(8, 128, 28, 256, 1, 2, 0) + verify_conv2d_hwnc(8, 256, 14, 256, 3, 1, 1) + verify_conv2d_hwnc(8, 256, 14, 512, 3, 2, 1) + verify_conv2d_hwnc(8, 256, 14, 512, 1, 2, 0) + verify_conv2d_hwnc(8, 512, 9, 512, 3, 1, 1) + +if __name__ == "__main__": + test_conv2d_hwnc_tensorcore() From 2750bfdfccdadfc3043bb7531a54fcc2d5534a49 Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Wed, 22 Jul 2020 14:46:20 +0000 Subject: [PATCH 11/21] remove useless code --- topi/python/topi/cuda/__init__.py | 1 - .../topi/cuda/conv2d_nhwc_tensorcore_int4.py | 341 ------------------ .../test_topi_conv2d_nhwc_tensorcore_int4.py | 319 ---------------- 3 files changed, 661 deletions(-) delete mode 100644 topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py delete mode 100644 topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index f75e72cff7a4..92a6b825f55d 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -46,7 +46,6 @@ from .rcnn import * from .sort import * from .conv2d_nhwc_tensorcore import * -from .conv2d_nhwc_tensorcore_int4 import * from .conv3d_ndhwc_tensorcore import * from .dense_tensorcore import * from .conv2d_hwnc_tensorcore import * \ No newline at end of file diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py deleted file mode 100644 index c772de54b703..000000000000 --- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore_int4.py +++ /dev/null @@ -1,341 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, too-many-locals, too-many-function-args -# pylint: disable=too-many-statements, unused-argument, too-many-arguments -"""Tensorcore template for cuda backend""" -import numpy as np -import tvm -from tvm import te -from tvm import autotvm -from ..util import get_const_tuple, traverse_inline, simplify -from ..nn.pad import pad -from ..nn.util import get_pad_tuple -from topi.cuda.injective import schedule_injective_from_existing -from .tensor_intrin import intrin_wmma_load_matrix_A, intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm - -def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype, out_dtype): - """Compute declaration for tensorcore""" - assert isinstance(stride, int) or len(stride) == 2 - assert isinstance(dilation, int) or len(dilation) == 2 - - if isinstance(stride, int): - stride_h = stride_w = stride - else: - stride_h, stride_w = stride - - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - - if in_dtype == 'int4': - wmma_n = wmma_m = 8 - wmma_k = 32 - else: - wmma_m = 8 - wmma_n = 32 - wmma_k = 16 - - batch, in_height, in_width, in_channels= get_const_tuple(Input.shape) - if in_dtype == 'int4' or in_dtype == 'int8': - kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape) - else: - kernel_h, kernel_w, _, num_filter, _, _ = get_const_tuple(Filter.shape) - if in_dtype == 'int4': - assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0) - else: - assert (batch % 16 == 0 and in_channels % 16 == 0 and num_filter % 16 == 0) or \ - (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0) or \ - (batch % 32 == 0 and in_channels % 16 == 0 and num_filter % 8 == 0), \ - "The shape of (batch, in_channels, num_filter) "\ - "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for fp16 and int8, "\ - "and (8, 32, 8) for int4" - - # compute the output shape - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - pad_top, pad_left, pad_down, pad_right = get_pad_tuple( - padding, (dilated_kernel_h, dilated_kernel_w)) - out_channels = num_filter - out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) - out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) - # Input feature map: (N, H, W, IC, n, ic) - data_shape = (batch // wmma_m, - in_height, - in_width, - in_channels // wmma_k, - wmma_m, - wmma_k) - # Kernel: (H, W, OC, IC, ic, oc) - kernel_shape = (kernel_h, - kernel_w, - out_channels // wmma_n, - in_channels // wmma_k, - wmma_n, - wmma_k) - output_shape = (batch, - out_height, - out_width, - out_channels) - # Reduction axes - kh = te.reduce_axis((0, kernel_h), name='kh') - kw = te.reduce_axis((0, kernel_w), name='kw') - ic = te.reduce_axis((0, in_channels // wmma_k), name='ic') - ii = te.reduce_axis((0, wmma_k), name='ii') - # Algorithm - A_transpose = te.compute(data_shape, - lambda n, h, w, i, nn, ii: Input[n * wmma_m + nn, h, w, i * wmma_k + ii].astype(in_dtype) - ) - Filter_transpose = te.compute(kernel_shape, - lambda kh, kw, o, i, oo, ii: Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii].astype(in_dtype) - ) - Apad_transpose = te.compute( - (batch // wmma_m, in_height + 2 * padding, in_width + 2 * padding, in_channels // wmma_k, wmma_m, - wmma_k), - lambda n, h, w, i, nn, ii: tvm.tir.if_then_else( - tvm.tir.all(h >= padding, h - padding < in_height, - w >= padding, w - padding < in_width), - A_transpose[n, h - padding, w - padding, i, nn, ii], tvm.tir.const(0., in_dtype)), - name='Apad') - Conv = te.compute((batch // wmma_m, out_height, out_width, out_channels // wmma_n, wmma_m, wmma_n), - lambda n, h, w, o, nn, oo: te.sum( - Apad_transpose[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("int32") * - Filter_transpose[kh, kw, o, ic, oo, ii].astype("int32"), - axis=[ic, kh, kw, ii]), - name="Conv") - Out = te.compute(output_shape, - lambda n, h, w, o: Conv[n // wmma_m, h, w, o // wmma_n, n % wmma_m, o % wmma_n], - name="Out", tag="conv2d_nhwc_tensorcore_int4") - return Out - - -def schedule_nhwc_tensorcore_cuda_int4(cfg, s, Out): - """Schedule tensorcore template""" - Conv = s[Out].op.input_tensors[0] - ic, kh, kw, ii = s[Conv].op.reduce_axis - out_dtype = Conv.dtype - Apad, kernel = s[Conv].op.input_tensors - A_transpose = s[Apad].op.input_tensors[0] - - in_dtype = Apad.dtype - batch, _, _, _, _, _ = get_const_tuple(Conv.shape) - # if in_dtype == 'int4': - # _, _, _, out_channels, _, _ = get_const_tuple(kernel.shape) - # else: - # _, _, out_channels, _, _, _ = get_const_tuple(kernel.shape) - # inline the pad and dtype transform - - block_x = te.thread_axis('blockIdx.x') - block_y = te.thread_axis('blockIdx.y') - block_z = te.thread_axis('blockIdx.z') - thread_x = te.thread_axis('threadIdx.x') - thread_y = te.thread_axis('threadIdx.y') - thread_z = te.thread_axis('threadIdx.z') - - # Designate the memory hierarchy - AS = s.cache_read(Apad, 'shared', [Conv]) - WS = s.cache_read(kernel, 'shared', [Conv]) - AF = s.cache_read(AS, 'wmma.matrix_a', [Conv]) - WF = s.cache_read(WS, 'wmma.matrix_b', [Conv]) - ConvF = s.cache_write(Conv, 'wmma.accumulator') - - # Schedule for autotvm - cfg.define_knob("block_row_warps", [1, 2]) - cfg.define_knob("block_col_warps", [1, 2]) - cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16]) - cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16]) - cfg.define_knob("chunk", [1, 2, 4, 8]) - cfg.define_knob("vector_ws", [1, 8]) - cfg.define_knob("inline_pad", [0, 1]) - cfg.define_knob("vector_as", [1, 4, 8, 16]) - cfg.define_knob("split_block_k", [1, 2, 4, 8]) - - # fallback support - target = tvm.target.Target.current() - if cfg.is_fallback: - ref_log = autotvm.tophub.load_reference_log( - target.target_name, target.model, 'conv2d_nhwc_tensorcore_int4.cuda') - cfg.fallback_with_reference_log(ref_log) - - block_row_warps = cfg["block_row_warps"].val - block_col_warps = cfg["block_col_warps"].val - warp_row_tiles = cfg["warp_row_tiles"].val - warp_col_tiles = cfg["warp_col_tiles"].val - chunk = cfg["chunk"].val - # offset = cfg["offset"].val - vector_ws = cfg["vector_ws"].val - vector_as = cfg["vector_as"].val - split_block_k = cfg["split_block_k"].val - inline_pad = cfg["inline_pad"].val - block_row_warps = 1 - block_col_warps = 1 - warp_row_tiles = 8 - warp_col_tiles = 4 - chunk = 2 - vector_ws = 1 - inline_pad = 0 - vector_as = 16 - split_block_k = 1 - - # inline_pad = 0 - - with tvm.target.create('cuda'): - schedule_injective_from_existing(s, Out) - schedule_injective_from_existing(s, kernel) - # schedule_injective_from_existing(s, Apad) - # schedule_injective_from_existing(s, A_transpose) - # s[kernel].compute_inline() - # s[Apad].compute_inline() - s[A_transpose].compute_inline() - # s[Out].compute_inline() - - if inline_pad: - s[Apad].compute_inline() - else: - with tvm.target.create('cuda'): - schedule_injective_from_existing(s, Apad) - - if in_dtype == 'int4': - wmma_m = wmma_n = 8 - wmma_k = 32 - else: - # if (batch % 16 == 0 and out_channels % 16 == 0): - # cfg.define_knob("wmma_m", [16, 8, 32]) - # elif (batch % 8 == 0 and out_channels % 32 == 0): - # cfg.define_knob("wmma_m", [8, 16, 32]) - # elif (batch % 32 == 0 and out_channels % 8 == 0): - # cfg.define_knob("wmma_m", [32, 16, 8]) - # wmma_m = cfg["wmma_m"].val - wmma_m = 8 - wmma_k = 16 - if wmma_m == 16: - wmma_n = 16 - elif wmma_m == 8: - wmma_n = 32 - elif wmma_m == 32: - wmma_n = 8 - - warp_size = 32 - - nc, hc, wc, oc, nnc, ooc = Conv.op.axis - block_k = s[Conv].fuse(hc, wc) - block_k, sub_block_k = s[Conv].split(block_k, factor=split_block_k) - nc, nci = s[Conv].split(nc, factor=warp_row_tiles) - block_i, nc = s[Conv].split(nc, factor=block_row_warps) - oc, oci = s[Conv].split(oc, factor=warp_col_tiles) - block_j, oc = s[Conv].split(oc, factor=block_col_warps) - s[Conv].reorder(block_k, block_i, block_j, sub_block_k, nc, oc, nci, oci, nnc, ooc) - s[Conv].bind(block_k, block_z) - s[Conv].bind(block_i, block_x) - s[Conv].bind(block_j, block_y) - s[Conv].bind(nc, thread_y) - s[Conv].bind(oc, thread_z) - - # Schedule local computation - s[ConvF].compute_at(s[Conv], oc) - n, h, w, o, nnf, oof = ConvF.op.axis - ko, ki = s[ConvF].split(ic, factor=chunk) - s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii) - - # Move intermediate computation into each output compute tile - s[AF].compute_at(s[ConvF], kw) - s[WF].compute_at(s[ConvF], kw) - - # Schedule for A's share memory - s[AS].compute_at(s[ConvF], kh) - n, h, w, i, nn, ii = AS.op.axis - tx, xo = s[AS].split(n, nparts=block_row_warps) - ty, yo = s[AS].split(xo, nparts=block_col_warps) - t = s[AS].fuse(nn, ii) - to, ti = s[AS].split(t, nparts=warp_size) - ti, _t = s[AS].split(ti, factor=vector_as) - s[AS].bind(tx, thread_y) - s[AS].bind(ty, thread_z) - s[AS].bind(to, thread_x) - s[AS].vectorize(_t) - - # Schedule for W's share memory - s[WS].compute_at(s[ConvF], kw) - kh, kw, ic, o, ii, oo = WS.op.axis - tx, xo = s[WS].split(o, nparts=block_row_warps) - ty, yo = s[WS].split(xo, nparts=block_col_warps) - t = s[WS].fuse(ii, oo) - to, ti = s[WS].split(t, nparts=warp_size) - ti, _t = s[WS].split(ti, factor=vector_ws) - s[WS].bind(tx, thread_y) - s[WS].bind(ty, thread_z) - s[WS].bind(to, thread_x) - s[WS].vectorize(ti) - - shape = (wmma_m, wmma_n, wmma_k) - - AS_shape = (wmma_m, wmma_k) - AL_shape = (wmma_m, wmma_k) - WS_shape = (wmma_n, wmma_k) - WL_shape = (wmma_n, wmma_k) - CL_shape = (wmma_m, wmma_n) - CS_shape = (wmma_m, wmma_n) - - AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype) - WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype) - k_gemm = te.reduce_axis((0, wmma_k), name="k") - CL_compute = te.compute(CL_shape, lambda ii, jj: - te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) * \ - WL_gemm[jj, k_gemm].astype(out_dtype), axis=k_gemm), - name='C') - AL_strides = [wmma_k, 1] - AS_strides = [wmma_k, 1] - WL_strides = [wmma_k, 1] - WS_strides = [wmma_k, 1] - CL_strides = [wmma_n, 1] - CS_strides = [wmma_n, 1] - - s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape, - "row_major", AS_shape, AL_shape, in_dtype)) - s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, - "col_major", WS_shape, WL_shape, in_dtype)) - s[Conv].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, - shape, out_dtype, CL_shape, CS_shape)) - s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, - WL_strides, CL_strides, shape)) - - - N, OH, OW, CO, nn, mm = get_const_tuple(Conv.shape) - if in_dtype == 'int4': - KH, KW, _, CI, _, ci = get_const_tuple(kernel.shape) - else: - KH, KW, _, CI, _, ci = get_const_tuple(kernel.shape) - cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * ci * nn * mm) - - -@autotvm.register_topi_compute("conv2d_nhwc_tensorcore_int4.cuda") -def conv2d_nhwc_tensorcore_int4(cfg, data, kernel, strides, padding, dilation, in_dtype, out_dtype): - """Compute conv2d with tensorcore for NCHW layout""" - return nhwc_tensorcore_cuda(cfg, data, kernel, strides, padding, dilation, in_dtype, out_dtype) - - -@autotvm.register_topi_schedule("conv2d_nhwc_tensorcore_int4.cuda") -def schedule_conv2d_nhwc_tensorcore_int4(cfg, outs): - """TOPI schedule callback""" - s = te.create_schedule([x.op for x in outs]) - def _callback(op): - if 'conv2d_nhwc_tensorcore_int4' in op.tag: - schedule_nhwc_tensorcore_cuda_int4(cfg, s, op.output(0)) - - traverse_inline(s, outs[0].op, _callback) - return s - diff --git a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py deleted file mode 100644 index 8c22842715c4..000000000000 --- a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore_int4.py +++ /dev/null @@ -1,319 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, too-many-locals, too-many-arguments -"""Example code to do convolution.""" - -import numpy as np -import tvm -import os -import topi -import topi.testing -from tvm import te, autotvm -from tvm.contrib.pickle_memoize import memoize -from tvm.contrib import nvcc -from topi.nn.util import get_pad_tuple -from topi.util import get_const_tuple - -TASK="conv_int4" - -USE_MANUAL_CODE = False - -# @tvm.register_func -# def tvm_callback_cuda_compile(code): -# ptx = nvcc.compile_cuda(code, target="ptx") -# return ptx - -def write_code(code, fname): - with open(fname, "w") as f: - f.write(code) - -@tvm.register_func -def tvm_callback_cuda_postproc(code): - if not os.path.exists("perf"): - os.mkdir("perf") - write_code(code, "perf/%s_generated.cu" % TASK) - if USE_MANUAL_CODE: - code = open("perf/%s_manual.cu" % TASK).read() - return code - - -_conv2d_nhwc_tensorcore_implement = { - "cuda": (topi.cuda.conv2d_nhwc_tensorcore_int4, topi.cuda.schedule_conv2d_nhwc_tensorcore_int4) -} - - -def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, - padding, dilation=1, add_bias=False, add_relu=False, devices='cuda'): - """Test the conv2d with tensorcore for nhwc layout""" - pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) - padding_sum = pad_top + pad_left + pad_bottom + pad_right - print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % ( - batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) - - # choose dtype from int4, int8 and float16 - dtype = 'int8' - out_dtype = 'int32' - if dtype == 'int4': - wmma_n = wmma_m = 8 - wmma_k = 32 - else: - wmma_m = 32 - wmma_n = 8 - wmma_k = 16 - in_height = in_width = in_size - - A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype=dtype) - - # A = te.placeholder((batch // wmma_m, in_height, in_width, in_channel // wmma_k, wmma_m, wmma_k), name='A', dtype=dtype) - if dtype == 'int4' or dtype == 'int8': - W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype) - # W = te.placeholder((kernel, kernel, in_channel // wmma_k, num_filter // wmma_n , wmma_n, wmma_k), name='W', dtype=dtype) - else: - W = te.placeholder((kernel, kernel, in_channel // wmma_k, num_filter // wmma_n , wmma_n, wmma_k), name='W', dtype=dtype) - - bias = te.placeholder((1, 1, 1, num_filter), name='bias', dtype=out_dtype) - - a_shape = get_const_tuple(A.shape) - w_shape = get_const_tuple(W.shape) - bias_shape = get_const_tuple(bias.shape) - # dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d_nhwc.verify_conv2d_nhwc") - def get_ref_data(): - np.random.seed(5) - if dtype == 'float16': - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = np.random.uniform(size=bias_shape).astype(out_dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - elif dtype == 'int4': - a_np = np.random.randint(low=1, high=7, size=a_shape).astype(np.int32) - b_np = np.random.randint(low=1, high=7, size=bias_shape).astype(np.int32) - w_np = np.random.randint(low=1, high=7, size=w_shape).astype(np.int32) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)).transpose((0, 1, 3, 2)) - elif dtype == 'int8': - a_np = np.random.randint(low=1, high=7, size=a_shape).astype(dtype) - w_np = np.random.randint(low=1, high=7, size=w_shape).astype(dtype) - b_np = np.random.randint(low=1, high=7, size=bias_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)).transpose((0, 1, 3, 2)) - - c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) - if add_bias: - # b_np = np.random.uniform(size=bias_shape).astype(out_dtype) - c_np += b_np - if add_relu: - c_np = np.maximum(c_np, 0) - return a_np, w_np, b_np, c_np - - def convert_int32_into_int4(a_int32): - """ convert int32 values into int4 - Parameters - ---------- - a_int32 : int - - Return - ------ - a_int4 : int - """ - I, J, K, L = a_int32.shape - a_int4 = np.zeros(shape=(I, J, K, L // 8), dtype=np.int32) - for i in range(I): - for j in range(J): - for k in range(K): - for l in range(L // 8): - for a in range(8): - a_int4[i,j,k,l] = a_int4[i,j,k,l] | ((a_int32[i,j,k,l * 8 + a] & 0xf) << ((7 - a) * 4)) - return a_int4 - - a_np, w_np, b_np, c_np = get_ref_data() - - if dtype == 'int4' or dtype == 'int8': - # a_np_tvm = a_np.reshape((batch // wmma_m, - # wmma_m, - # in_height, - # in_width, - # in_channel // wmma_k, - # wmma_k)).transpose((0,2,3,4,1,5)) - # w_np = w_np.reshape((kernel, - # kernel, - # in_channel // wmma_k, - # wmma_k, - # num_filter // wmma_n, - # wmma_n)).transpose((0,1,2,4,5,3)) - if dtype == 'int4': - a_np = convert_int32_into_int4(a_np) - # b_np = convert_int32_into_int4(b_np) - w_np = convert_int32_into_int4(w_np) - - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - if not nvcc.have_tensorcore(ctx.compute_version): - print("skip because gpu does not support Tensor Cores") - return - print("Running on target: %s" % device) - with tvm.target.create(device): - fcompute, fschedule = topi.testing.dispatch(device, _conv2d_nhwc_tensorcore_implement) - if dtype == 'float16': - C = fcompute(A, W, stride, padding, dilation, dtype, 'float') - else: - C = fcompute(A, W, stride, padding, dilation, dtype, 'int32') - if add_bias: - C = topi.add(C, bias) - if add_relu: - C = topi.nn.relu(C) - s = fschedule([C]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) - if add_bias: - func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( - batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) - func(a, w, b, c) - else: - print(tvm.lower(s, [A, W, C], simple_mode=True)) - func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( - batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) - func(a, w, c) - dev_module = func.imported_modules[0] - # print(dev_module.get_source()) - # warm up - evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) - evaluator(a, w, c) - print('Time cost of this operator: %f ms' % (evaluator(a, w, c).mean * 1000)) - - rtol = 1e-3 - # print(c.asnumpy().sum(), c_np.sum()) - tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol) - - # # #Tuning the performance - # import logging, sys - # logging.getLogger('autotvm').setLevel(logging.DEBUG) - # logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) - - # log_filename = "conv2d_" + dtype +"_nhwc_tensorcore_kernel_shape_%d_%d_%d_%d_%d_%d_%d_%d.log" % (batch, in_channel, in_size, num_filter, kernel, stride, - # padding, dilation) - # tmp_log_file = log_filename + '.temp' - # num_trial = 1000 - # task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, - # padding, dilation) - # task = autotvm.create('conv2d_nhwc_tensorcore_int4.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) - # print(task.config_space) - - # measure_option = autotvm.measure_option( - # builder='local', - # runner=autotvm.LocalRunner(number=5)) - - # tuner = autotvm.tuner.XGBTuner(task, feature_type='knob') - # num_trial = min(num_trial, len(task.config_space)) - # with tvm.target.build_config(): - # tuner.tune(n_trial=num_trial, - # measure_option=measure_option, - # callbacks=[autotvm.callback.progress_bar(num_trial, prefix=task_name), - # autotvm.callback.log_to_file(tmp_log_file)]) - - # dispatch_context = autotvm.apply_history_best(tmp_log_file) - # best_config = dispatch_context.query(task.target, task.workload) - # print("\nBest config:") - # print(best_config) - - # #pick the best record to a cache file - # autotvm.record.pick_best(tmp_log_file, log_filename) - # os.remove(tmp_log_file) - - # with autotvm.apply_graph_best(log_filename): - # with tvm.target.create(device): - # func = tvm.build(s, [A, W, C], device, name="conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, - # padding, dilation)) - # evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) - # print('Time cost of this operator after tuning: %f ms' % (evaluator(a, w, c).mean * 1000)) - - check_device(devices) - - -def test_conv2d_nhwc_tensorcore(): - """Test the conv2d with tensorcore for nhwc layout""" - # verify_conv2d_nhwc(64, 64, 56, 64, 3, 1, 1) - # verify_conv2d_nhwc(64, 64, 56, 64, 1, 1, 0) - # verify_conv2d_nhwc(64, 64, 56, 128, 3, 2, 1) - # verify_conv2d_nhwc(64, 64, 56, 64, 1, 2, 0) - # verify_conv2d_nhwc(64, 128, 28, 128, 3, 1, 1) - # verify_conv2d_nhwc(64, 128, 28, 256, 3, 2, 1) - # verify_conv2d_nhwc(64, 128, 28, 256, 1, 2, 0) - # verify_conv2d_nhwc(64, 256, 14, 256, 3, 1, 1) - # verify_conv2d_nhwc(64, 256, 14, 512, 3, 2, 1) - # verify_conv2d_nhwc(64, 256, 14, 512, 1, 2, 0) - # verify_conv2d_nhwc(64, 512, 7, 512, 3, 1, 1) - - # verify_conv2d_nhwc(32, 64, 56, 64, 3, 1, 1) - # verify_conv2d_nhwc(32, 64, 56, 64, 1, 1, 0) - # verify_conv2d_nhwc(32, 64, 56, 128, 3, 2, 1) - # verify_conv2d_nhwc(32, 64, 56, 64, 1, 2, 0) - # verify_conv2d_nhwc(32, 128, 28, 128, 3, 1, 1) - # verify_conv2d_nhwc(32, 128, 28, 256, 3, 2, 1) - # verify_conv2d_nhwc(32, 128, 28, 256, 1, 2, 0) - # verify_conv2d_nhwc(32, 256, 14, 256, 3, 1, 1) - # verify_conv2d_nhwc(32, 256, 14, 512, 3, 2, 1) - # verify_conv2d_nhwc(32, 256, 14, 512, 1, 2, 0) - # verify_conv2d_nhwc(32, 512, 7, 512, 3, 1, 1) - - # verify_conv2d_nhwc(16, 64, 56, 64, 3, 1, 1) - # verify_conv2d_nhwc(16, 64, 56, 64, 1, 1, 0) - # verify_conv2d_nhwc(16, 64, 56, 128, 3, 2, 1) - # verify_conv2d_nhwc(16, 64, 56, 64, 1, 2, 0) - # verify_conv2d_nhwc(16, 128, 28, 128, 3, 1, 1) - # verify_conv2d_nhwc(16, 128, 28, 256, 3, 2, 1) - # verify_conv2d_nhwc(16, 128, 28, 256, 1, 2, 0) - # verify_conv2d_nhwc(16, 256, 14, 256, 3, 1, 1) - # verify_conv2d_nhwc(16, 256, 14, 512, 3, 2, 1) - # verify_conv2d_nhwc(16, 256, 14, 512, 1, 2, 0) - # verify_conv2d_nhwc(16, 512, 7, 512, 3, 1, 1) - - verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) - # verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) - # verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) - # verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) - # verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1) - # verify_conv2d_nhwc(8, 128, 28, 256, 3, 2, 1) - # verify_conv2d_nhwc(8, 128, 28, 256, 1, 2, 0) - # verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1) - # verify_conv2d_nhwc(8, 256, 14, 512, 3, 2, 1) - # verify_conv2d_nhwc(8, 256, 14, 512, 1, 2, 0) - # verify_conv2d_nhwc(8, 512, 7, 512, 3, 1, 1) - - - # verify_conv2d_nhwc(32, 1024, 14, 256, 1, 1, 1) - - # verify_conv2d_nhwc(16, 128, 7, 128, 7, 1, 3) - # verify_conv2d_nhwc(16, 160, 7, 160, 7, 1, 3) - - # verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_bias=True) - # verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True) - # verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True, add_bias=True) - - # verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, (3, 3, 2, 2)) - # verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, "SAME") - # verify_conv2d_nhwc(16, 48, 35, 48, 5, 1, "VALID") - # verify_conv2d_nhwc(16, 48, 56, 48, 3, 1, (1, 1, 1, 1)) - # verify_conv2d_nhwc(16, 64, 28, 64, 3, 1, (1, 1, 1, 1)) - - -if __name__ == "__main__": - test_conv2d_nhwc_tensorcore() From b5dd8d77422a82f52388c48b8dc3b53f3b3f8125 Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Wed, 22 Jul 2020 15:18:19 +0000 Subject: [PATCH 12/21] remove useless code --- .../topi/cuda/conv2d_nhwc_tensorcore.py | 178 +++++---------- topi/python/topi/cuda/tensor_intrin.py | 2 +- .../test_topi_conv2d_nhwc_tensorcore.py | 213 ++---------------- 3 files changed, 79 insertions(+), 314 deletions(-) diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py index 5b1776941bd1..b9ccdacf0734 100644 --- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py +++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py @@ -30,7 +30,7 @@ from .tensor_intrin import intrin_wmma_gemm -def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype, out_dtype): +def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype): """Compute declaration for tensorcore""" assert isinstance(stride, int) or len(stride) == 2 assert isinstance(dilation, int) or len(dilation) == 2 @@ -46,20 +46,12 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = get_const_tuple(Input.shape) - if in_dtype == 'int4': - kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape) - else: - kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape) - - if in_dtype == 'int4': - assert (batch % 8 == 0 and in_channel % 32 == 0 and num_filter % 8 == 0) - else: - assert (batch % 16 == 0 and in_channel % 16 == 0 and num_filter % 16 == 0) or \ + kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape) + assert (batch % 16 == 0 and in_channel % 16 == 0 and num_filter % 16 == 0) or \ (batch % 8 == 0 and in_channel % 16 == 0 and num_filter % 32 == 0) or \ (batch % 32 == 0 and in_channel % 16 == 0 and num_filter % 8 == 0), \ "The shape of (batch, in_channel, num_filter) "\ - "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for fp16 and int8, "\ - "and (8, 32, 8) for int4" + "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now" # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 @@ -78,26 +70,16 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, in_dtype # convert data type of input feature maps and weights TransPaddedInput = te.compute( PaddedInput.shape, - lambda n, h, w, c: PaddedInput[n, h, w, c].astype(in_dtype)) + lambda n, h, w, c: PaddedInput[n, h, w, c].astype('float16')) TransFilter = te.compute( - Filter.shape, lambda h, w, i, o: Filter[h, w, i, o].astype(in_dtype)) - if in_dtype == 'int4': - Output = te.compute( - (batch, out_height, out_width, out_channel), - lambda nn, yy, xx, ff: te.sum( - TransPaddedInput[nn, yy * stride_h + ry * dilation_h, - xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * - TransFilter[ry, rx, ff, rc].astype(out_dtype), axis=[ry, rx, rc]), - name="Conv2dOutput", tag="conv2d_nhwc_tensorcore") - else: - Output = te.compute( - (batch, out_height, out_width, out_channel), - lambda nn, yy, xx, ff: te.sum( - TransPaddedInput[nn, yy * stride_h + ry * dilation_h, - xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * - TransFilter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]), - name="Conv2dOutput", tag="conv2d_nhwc_tensorcore") - + Filter.shape, lambda h, w, i, o: Filter[h, w, i, o].astype('float16')) + Output = te.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: te.sum( + TransPaddedInput[nn, yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * + TransFilter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]), + name="Conv2dOutput", tag="conv2d_nhwc_tensorcore") return Output @@ -108,12 +90,9 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv): trans_paddata, kernel = s[Conv].op.input_tensors in_dtype = trans_paddata.dtype batch, _, _, _ = get_const_tuple(Conv.shape) - - if in_dtype == 'int4': - _, _, out_channels, _ = get_const_tuple(kernel.shape) - else: - _, _, _, out_channels = get_const_tuple(kernel.shape) + _, _, _, out_channels = get_const_tuple(kernel.shape) paddata = s[trans_paddata].op.input_tensors + # inline the pad and dtype transform s[trans_paddata].compute_inline() s[kernel].compute_inline() @@ -141,20 +120,21 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv): cfg.define_knob("warp_row_tiles", [1, 2, 4]) cfg.define_knob("warp_col_tiles", [1, 2, 4]) cfg.define_knob("chunk", [1, 2, 4, 8]) - if in_dtype == 'int8': - cfg.define_knob("offset", [0, 16]) - elif in_dtype == 'int4': - cfg.define_knob("offset", [0]) - else: - cfg.define_knob("offset", [0, 8]) + cfg.define_knob("offset", [0, 8]) cfg.define_knob("vector_width", [1, 2, 4, 8]) - # cfg.define_knob("vector_width", [1]) + + if (batch % 16 == 0 and out_channels % 16 == 0): + cfg.define_knob("wmma_m", [16, 8, 32]) + elif (batch % 8 == 0 and out_channels % 32 == 0): + cfg.define_knob("wmma_m", [8, 16, 32]) + elif (batch % 32 == 0 and out_channels % 8 == 0): + cfg.define_knob("wmma_m", [32, 16, 8]) # fallback support target = tvm.target.Target.current() if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( - target.target_name, target.model, 'conv2d_nhwc_tensorcore.cuda') + target.id.name, target.model, 'conv2d_nhwc_tensorcore.cuda') cfg.fallback_with_reference_log(ref_log) block_row_warps = cfg["block_row_warps"].val @@ -163,34 +143,16 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv): warp_col_tiles = cfg["warp_col_tiles"].val chunk = cfg["chunk"].val offset = cfg["offset"].val + wmma_m = cfg["wmma_m"].val vector_width = cfg["vector_width"].val - block_row_warps = 1 - block_col_warps = 4 - warp_row_tiles = 2 - warp_col_tiles = 2 - chunk = 1 - offset = 0 - vector_width = 1 - - if in_dtype == 'int4': - wmma_m = wmma_n = 8 - wmma_k = 32 - else: - if (batch % 16 == 0 and out_channels % 16 == 0): - cfg.define_knob("wmma_m", [16, 8, 32]) - elif (batch % 8 == 0 and out_channels % 32 == 0): - cfg.define_knob("wmma_m", [8, 16, 32]) - elif (batch % 32 == 0 and out_channels % 8 == 0): - cfg.define_knob("wmma_m", [32, 16, 8]) - wmma_m = cfg["wmma_m"].val - wmma_m = 16 - wmma_k = 16 - if wmma_m == 16: - wmma_n = 16 - elif wmma_m == 8: - wmma_n = 32 - elif wmma_m == 32: - wmma_n = 8 + + wmma_k = 16 + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 warp_size = 32 @@ -206,20 +168,17 @@ def get_strides(extents): return [np.prod(extents[i:]).tolist() for i in range(len(extents))] AS_align = chunk * wmma_k + offset - if in_dtype == 'int4': - WS_align = chunk * warp_col_tiles * block_col_warps * wmma_k + offset - WL_strides = get_strides([wmma_k * warp_col_tiles, 1]) - else: - WS_align = warp_col_tiles * block_col_warps * wmma_n + offset - WL_strides = get_strides([wmma_n * warp_col_tiles, 1]) + WS_align = warp_col_tiles * block_col_warps * wmma_n + offset block_factor_n = wmma_m * warp_row_tiles * block_row_warps block_factor_o = wmma_n * warp_col_tiles * block_col_warps CS_align = block_factor_o + offset AS_strides = get_strides([1, 1, AS_align, 1]) AL_strides = get_strides([1, 1, wmma_k, 1]) WS_strides = get_strides([WS_align, 1]) + WL_strides = get_strides([wmma_n * warp_col_tiles, 1]) CL_strides = get_strides([1, 1, wmma_n * warp_col_tiles, 1]) CS_strides = get_strides([1, 1, CS_align, 1]) + # Schedule for output nc, hc, wc, oc = output.op.axis block_k = s[output].fuse(hc, wc) @@ -263,8 +222,8 @@ def get_strides(extents): ko, ki = s[ConvF].split(ic, factor=chunk) s[ConvF].reorder(kh, kw, ko, ki, n, o, nnf, oof, ii) - s[AF].compute_at(s[ConvF], n) - s[WF].compute_at(s[ConvF], n) + s[AF].compute_at(s[ConvF], ki) + s[WF].compute_at(s[ConvF], ki) # Schedule wmma load n, h, w, i = AF.op.axis @@ -272,20 +231,11 @@ def get_strides(extents): i, ii = s[AF].split(i, factor=wmma_k) s[AF].reorder(n, i, nn, ii) - # kh, kw, i, o = WF.op.axis - if in_dtype == 'int4': - kh, kw, o, i = WF.op.axis - # print('kh, kw, o, i', kh, kw, o, i) - i, ii = s[WF].split(i, factor=wmma_k) - o, oo = s[WF].split(o, factor=wmma_n) - s[WF].reorder(o, i, oo) - s[WF].reorder(o, i, oo, ii) - else: - kh, kw, i, o = WF.op.axis - i, ii = s[WF].split(i, factor=wmma_k) - o, oo = s[WF].split(o, factor=wmma_n) - s[WF].reorder(o, i, oo) - s[WF].reorder(i, o, ii, oo) + kh, kw, i, o = WF.op.axis + i, ii = s[WF].split(i, factor=wmma_k) + o, oo = s[WF].split(o, factor=wmma_n) + s[WF].reorder(o, i, oo) + s[WF].reorder(i, o, ii, oo) s[WS].compute_at(s[ConvF], ko) s[AS].compute_at(s[ConvF], ko) @@ -322,54 +272,37 @@ def get_strides(extents): # tensorize the wmma process AS_shape = (wmma_m, 1, 1, wmma_k) AL_shape = (wmma_m, 1, 1, wmma_k) - if in_dtype == 'int4': - WS_shape = (wmma_n, wmma_k) - WL_shape = (wmma_n, wmma_k) - else: - WS_shape = (wmma_k, wmma_n) - WL_shape = (wmma_k, wmma_n) + WS_shape = (wmma_k, wmma_n) + WL_shape = (wmma_k, wmma_n) CL_shape = (wmma_m, 1, 1, wmma_n) CS_shape = (wmma_m, 1, 1, wmma_n) AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype) WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype) k_gemm = te.reduce_axis((0, wmma_k), name="k") - if in_dtype == 'int4': - CL_compute = te.compute(CL_shape, lambda ii, t0, t1, jj: - te.sum(AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * \ - WL_gemm[jj, k_gemm].astype(out_dtype), axis=k_gemm), - name='C') - else: - CL_compute = te.compute(CL_shape, lambda ii, t0, t1, jj: - te.sum(AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * \ - WL_gemm[k_gemm, jj].astype(out_dtype), axis=k_gemm), - name='C') + CL_compute = te.compute(CL_shape, lambda ii, t0, t1, jj: + te.sum(AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * \ + WL_gemm[k_gemm, jj].astype(out_dtype), axis=k_gemm), + name='C') s[AF].tensorize(nn, intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape, "row_major", AS_shape, AL_shape, in_dtype)) - if in_dtype == 'int4': - s[WF].tensorize(oo, intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, - "col_major", WS_shape, WL_shape, in_dtype)) - else: - s[WF].tensorize(ii, intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, - "row_major", WS_shape, WL_shape, in_dtype)) + s[WF].tensorize(ii, intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, + "row_major", WS_shape, WL_shape, in_dtype)) s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, shape, out_dtype, CL_shape, CS_shape)) s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, WL_strides, CL_strides, shape)) N, OH, OW, CO = get_const_tuple(output.shape) - if in_dtype == 'int4': - KH, KW, _, CI = get_const_tuple(kernel.shape) - else: - KH, KW, CI, _ = get_const_tuple(kernel.shape) + KH, KW, CI, _ = get_const_tuple(kernel.shape) cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW) @autotvm.register_topi_compute("conv2d_nhwc_tensorcore.cuda") -def conv2d_nhwc_tensorcore(cfg, data, kernel, strides, padding, dilation, in_dtype, out_dtype): +def conv2d_nhwc_tensorcore(cfg, data, kernel, strides, padding, dilation, out_dtype): """Compute conv2d with tensorcore for NCHW layout""" - return nhwc_tensorcore_cuda(cfg, data, kernel, strides, padding, dilation, in_dtype, out_dtype) + return nhwc_tensorcore_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype) @autotvm.register_topi_schedule("conv2d_nhwc_tensorcore.cuda") @@ -382,5 +315,4 @@ def _callback(op): schedule_nhwc_tensorcore_cuda(cfg, s, op.output(0)) traverse_inline(s, outs[0].op, _callback) - return s - + return s \ No newline at end of file diff --git a/topi/python/topi/cuda/tensor_intrin.py b/topi/python/topi/cuda/tensor_intrin.py index f1be5e506bd5..7e56a9678b43 100644 --- a/topi/python/topi/cuda/tensor_intrin.py +++ b/topi/python/topi/cuda/tensor_intrin.py @@ -146,7 +146,7 @@ def intrin_wmma_store_matrix(strides_dst, strides_from, shape, out_dtype, A_shap offset_factor=8) C = te.compute(C_shape, lambda *i: A(*i), name='C') BC = tvm.tir.decl_buffer(C.shape, C.dtype, - scope='global', strides=strides_dst, + scope='shared', strides=strides_dst, data_alignment=32, offset_factor=8) def intrin_func(ins, outs): diff --git a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py index 84fd4ea48927..ddd0f8e3ce1f 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py @@ -19,37 +19,14 @@ import numpy as np import tvm -import os import topi import topi.testing -from tvm import te, autotvm +from tvm import te from tvm.contrib.pickle_memoize import memoize from tvm.contrib import nvcc from topi.nn.util import get_pad_tuple from topi.util import get_const_tuple -TASK="conv_int4" - -USE_MANUAL_CODE = False - -# @tvm.register_func -# def tvm_callback_cuda_compile(code): -# ptx = nvcc.compile_cuda(code, target="ptx") -# return ptx - -def write_code(code, fname): - with open(fname, "w") as f: - f.write(code) - -@tvm.register_func -def tvm_callback_cuda_postproc(code): - if not os.path.exists("perf"): - os.mkdir("perf") - write_code(code, "perf/%s_generated.cu" % TASK) - if USE_MANUAL_CODE: - code = open("perf/%s_manual.cu" % TASK).read() - return code - _conv2d_nhwc_tensorcore_implement = { "cuda": (topi.cuda.conv2d_nhwc_tensorcore, topi.cuda.schedule_conv2d_nhwc_tensorcore) @@ -64,77 +41,32 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % ( batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) - # choose dtype from int4, int8 and float16 - dtype = 'int4' - out_dtype = 'int32' - in_height = in_width = in_size - A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype=dtype) - if dtype == 'int4': - W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype) - else: - W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype=dtype) - - bias = te.placeholder((1, 1, 1, num_filter), name='bias', dtype=out_dtype) + A = te.placeholder((batch, in_height, in_width, in_channel), name='A') + W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W') + bias = te.placeholder((1, 1, 1, num_filter), name='bias') a_shape = get_const_tuple(A.shape) w_shape = get_const_tuple(W.shape) bias_shape = get_const_tuple(bias.shape) - # dtype = A.dtype + dtype = A.dtype @memoize("topi.tests.test_topi_conv2d_nhwc.verify_conv2d_nhwc") def get_ref_data(): - if dtype == 'float16': - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = np.random.uniform(size=bias_shape).astype(out_dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - elif dtype == 'int4': - a_np = np.random.randint(low=1, high=7, size=a_shape) - b_np = np.random.randint(low=1, high=7, size=bias_shape) - w_np = np.random.randint(low=1, high=7, size=w_shape) - dw_np = topi.testing.dilate_python(w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation)) - elif dtype == 'int8': - a_np = np.random.randint(low=1, high=7, size=a_shape).astype(dtype) - w_np = np.random.randint(low=1, high=7, size=w_shape).astype(dtype) - b_np = np.random.randint(low=1, high=7, size=bias_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) if add_bias: - # b_np = np.random.uniform(size=bias_shape).astype(out_dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) c_np += b_np if add_relu: c_np = np.maximum(c_np, 0) return a_np, w_np, b_np, c_np - - def convert_int32_into_int4(a_int32): - """ convert int32 values into int4 - Parameters - ---------- - a_int32 : int - - Return - ------ - a_int4 : int - """ - I, J, K, L = a_int32.shape - a_int4 = np.zeros(shape=(I, J, K, L // 8), dtype=np.int32) - # for g in range(G): - for i in range(I): - for j in range(J): - for k in range(K): - for l in range(L // 8): - for m in range(min(8, L-l*8)): - a_int4[i, j, k, l] = a_int4[i, j, k, l] | ((a_int32[i, j, k, l * 8 + m] & 0xf) << ((7 - m) * 4)) - return a_int4 a_np, w_np, b_np, c_np = get_ref_data() - if dtype == 'int4': - a_np = convert_int32_into_int4(a_np) - b_np = convert_int32_into_int4(b_np) - w_np = convert_int32_into_int4(w_np) def check_device(device): ctx = tvm.context(device, 0) @@ -147,10 +79,7 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): fcompute, fschedule = topi.testing.dispatch(device, _conv2d_nhwc_tensorcore_implement) - if dtype == 'float16': - C = fcompute(A, W, stride, padding, dilation, dtype, 'float') - else: - C = fcompute(A, W, stride, padding, dilation, dtype, 'int32') + C = fcompute(A, W, stride, padding, dilation, 'float32') if add_bias: C = topi.add(C, bias) if add_relu: @@ -166,128 +95,32 @@ def check_device(device): batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, b, c) else: - # print(tvm.lower(s, [A, W, C], simple_mode=True)) func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, c) - # warm up - evaluator = func.time_evaluator(func.entry_name, ctx, number=50, repeat=20) - evaluator(a, w, c) - print('Time cost of this operator: %f ms' % (evaluator(a, w, c).mean * 1000)) rtol = 1e-3 tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol) - # # #Tuning the performance - # import logging, sys - # logging.getLogger('autotvm').setLevel(logging.DEBUG) - # logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) - - # log_filename = "conv2d_int4_nhwc_tensorcore.log" - # tmp_log_file = log_filename + '.temp' - # num_trial = 1000 - # task_name = "conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, - # padding, dilation) - # task = autotvm.create('conv2d_nhwc_tensorcore.cuda', args=[A, W, stride, padding, dilation, dtype, out_dtype], target=device) - # print(task.config_space) - - # measure_option = autotvm.measure_option( - # builder='local', - # runner=autotvm.LocalRunner(number=5)) - - # tuner = autotvm.tuner.XGBTuner(task) - # num_trial = min(num_trial, len(task.config_space)) - # with tvm.target.build_config(): - # tuner.tune(n_trial=num_trial, - # measure_option=measure_option, - # callbacks=[autotvm.callback.progress_bar(num_trial, prefix=task_name), - # autotvm.callback.log_to_file(tmp_log_file)]) - - # dispatch_context = autotvm.apply_history_best(tmp_log_file) - # best_config = dispatch_context.query(task.target, task.workload) - # print("\nBest config:") - # print(best_config) - - # #pick the best record to a cache file - # autotvm.record.pick_best(tmp_log_file, log_filename) - # os.remove(tmp_log_file) - - # with autotvm.apply_graph_best(log_filename): - # with tvm.target.create(device): - # func = tvm.build(s, [A, W, C], device, name="conv2d_nhwc_tensorcore_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, - # padding, dilation)) - # evaluator = func.time_evaluator(func.entry_name, ctx, number=100, repeat=10) - # print('Time cost of this operator after tuning: %f ms' % (evaluator(a, w, c).mean * 1000)) - check_device(devices) def test_conv2d_nhwc_tensorcore(): """Test the conv2d with tensorcore for nhwc layout""" - # verify_conv2d_nhwc(64, 64, 56, 64, 3, 1, 1) - # verify_conv2d_nhwc(64, 64, 56, 64, 1, 1, 0) - # verify_conv2d_nhwc(64, 64, 56, 128, 3, 2, 1) - # verify_conv2d_nhwc(64, 64, 56, 64, 1, 2, 0) - # verify_conv2d_nhwc(64, 128, 28, 128, 3, 1, 1) - # verify_conv2d_nhwc(64, 128, 28, 256, 3, 2, 1) - # verify_conv2d_nhwc(64, 128, 28, 256, 1, 2, 0) - # verify_conv2d_nhwc(64, 256, 14, 256, 3, 1, 1) - # verify_conv2d_nhwc(64, 256, 14, 512, 3, 2, 1) - # verify_conv2d_nhwc(64, 256, 14, 512, 1, 2, 0) - # verify_conv2d_nhwc(64, 512, 7, 512, 3, 1, 1) - - # verify_conv2d_nhwc(32, 64, 56, 64, 3, 1, 1) - # verify_conv2d_nhwc(32, 64, 56, 64, 1, 1, 0) - # verify_conv2d_nhwc(32, 64, 56, 128, 3, 2, 1) - # verify_conv2d_nhwc(32, 64, 56, 64, 1, 2, 0) - # verify_conv2d_nhwc(32, 128, 28, 128, 3, 1, 1) - # verify_conv2d_nhwc(32, 128, 28, 256, 3, 2, 1) - # verify_conv2d_nhwc(32, 128, 28, 256, 1, 2, 0) - # verify_conv2d_nhwc(32, 256, 14, 256, 3, 1, 1) - # verify_conv2d_nhwc(32, 256, 14, 512, 3, 2, 1) - # verify_conv2d_nhwc(32, 256, 14, 512, 1, 2, 0) - # verify_conv2d_nhwc(32, 512, 7, 512, 3, 1, 1) - - # verify_conv2d_nhwc(16, 64, 56, 64, 3, 1, 1) - # verify_conv2d_nhwc(16, 64, 56, 64, 1, 1, 0) - # verify_conv2d_nhwc(16, 64, 56, 128, 3, 2, 1) - # verify_conv2d_nhwc(16, 64, 56, 64, 1, 2, 0) - # verify_conv2d_nhwc(16, 128, 28, 128, 3, 1, 1) - # verify_conv2d_nhwc(16, 128, 28, 256, 3, 2, 1) - # verify_conv2d_nhwc(16, 128, 28, 256, 1, 2, 0) - # verify_conv2d_nhwc(16, 256, 14, 256, 3, 1, 1) - # verify_conv2d_nhwc(16, 256, 14, 512, 3, 2, 1) - # verify_conv2d_nhwc(16, 256, 14, 512, 1, 2, 0) - # verify_conv2d_nhwc(16, 512, 7, 512, 3, 1, 1) - - verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1) - # verify_conv2d_nhwc(8, 64, 56, 64, 1, 1, 0) - verify_conv2d_nhwc(8, 64, 56, 128, 3, 2, 1) - # verify_conv2d_nhwc(8, 64, 56, 64, 1, 2, 0) - # verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1) - # verify_conv2d_nhwc(8, 128, 28, 256, 3, 2, 1) - # verify_conv2d_nhwc(8, 128, 28, 256, 1, 2, 0) - # verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1) - # verify_conv2d_nhwc(8, 256, 14, 512, 3, 2, 1) - # verify_conv2d_nhwc(8, 256, 14, 512, 1, 2, 0) - # verify_conv2d_nhwc(8, 512, 7, 512, 3, 1, 1) - - - # verify_conv2d_nhwc(32, 1024, 14, 256, 1, 1, 1) - - # verify_conv2d_nhwc(16, 128, 7, 128, 7, 1, 3) - # verify_conv2d_nhwc(16, 160, 7, 160, 7, 1, 3) + verify_conv2d_nhwc(16, 16, 14, 16, 3, 1, 1) + verify_conv2d_nhwc(16, 128, 7, 128, 7, 1, 3) + verify_conv2d_nhwc(16, 160, 7, 160, 7, 1, 3) - # verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_bias=True) - # verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True) - # verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True, add_bias=True) + verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_bias=True) + verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True) + verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True, add_bias=True) - # verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, (3, 3, 2, 2)) - # verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, "SAME") - # verify_conv2d_nhwc(16, 48, 35, 48, 5, 1, "VALID") - # verify_conv2d_nhwc(16, 48, 56, 48, 3, 1, (1, 1, 1, 1)) - # verify_conv2d_nhwc(16, 64, 28, 64, 3, 1, (1, 1, 1, 1)) + verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, (3, 3, 2, 2)) + verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, "SAME") + verify_conv2d_nhwc(16, 48, 35, 48, 5, 1, "VALID") + verify_conv2d_nhwc(16, 48, 56, 48, 3, 1, (1, 1, 1, 1)) + verify_conv2d_nhwc(16, 64, 28, 64, 3, 1, (1, 1, 1, 1)) if __name__ == "__main__": - test_conv2d_nhwc_tensorcore() + test_conv2d_nhwc_tensorcore() \ No newline at end of file From f6843dfc33c57ed62e88effaff97dfc296e7a1ac Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Wed, 22 Jul 2020 15:22:04 +0000 Subject: [PATCH 13/21] remove useless code --- topi/python/topi/cuda/conv2d_hwnc_tensorcore.py | 1 - topi/python/topi/cuda/conv2d_nhwc_tensorcore.py | 4 ++-- topi/python/topi/cuda/tensor_intrin.py | 3 ++- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py b/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py index 621ed343413f..1a8974528c1b 100644 --- a/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py +++ b/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py @@ -402,4 +402,3 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s - diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py index b9ccdacf0734..790db0fe89a0 100644 --- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py +++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py @@ -134,7 +134,7 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv): target = tvm.target.Target.current() if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( - target.id.name, target.model, 'conv2d_nhwc_tensorcore.cuda') + target.target_name, target.model, 'conv2d_nhwc_tensorcore.cuda') cfg.fallback_with_reference_log(ref_log) block_row_warps = cfg["block_row_warps"].val @@ -315,4 +315,4 @@ def _callback(op): schedule_nhwc_tensorcore_cuda(cfg, s, op.output(0)) traverse_inline(s, outs[0].op, _callback) - return s \ No newline at end of file + return s diff --git a/topi/python/topi/cuda/tensor_intrin.py b/topi/python/topi/cuda/tensor_intrin.py index 7e56a9678b43..436b5055f30f 100644 --- a/topi/python/topi/cuda/tensor_intrin.py +++ b/topi/python/topi/cuda/tensor_intrin.py @@ -222,4 +222,5 @@ def update(): return update(), init(), update() - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) \ No newline at end of file + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + \ No newline at end of file From b0cffecb53f004aab3c66cc24b8cb7bcc026115f Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Wed, 22 Jul 2020 15:26:30 +0000 Subject: [PATCH 14/21] remove useless code --- topi/python/topi/cuda/tensor_intrin.py | 1 - topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/topi/python/topi/cuda/tensor_intrin.py b/topi/python/topi/cuda/tensor_intrin.py index 436b5055f30f..3941c00cc464 100644 --- a/topi/python/topi/cuda/tensor_intrin.py +++ b/topi/python/topi/cuda/tensor_intrin.py @@ -223,4 +223,3 @@ def update(): return update(), init(), update() return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) - \ No newline at end of file diff --git a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py index ddd0f8e3ce1f..cc327849caea 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py @@ -123,4 +123,4 @@ def test_conv2d_nhwc_tensorcore(): if __name__ == "__main__": - test_conv2d_nhwc_tensorcore() \ No newline at end of file + test_conv2d_nhwc_tensorcore() From ae0d45c6d5ff541e074dee7193f843a24245fc4c Mon Sep 17 00:00:00 2001 From: GaryYuyjl Date: Thu, 23 Jul 2020 02:04:41 +0000 Subject: [PATCH 15/21] fix int8 transpose --- topi/python/topi/cuda/conv2d_hwnc_tensorcore.py | 4 +--- topi/tests/python/test_topi_conv2d_hwnc_tensorcore.py | 8 +++----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py b/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py index 1a8974528c1b..568887795e84 100644 --- a/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py +++ b/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py @@ -142,8 +142,7 @@ def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp pad_after = [pad_down, pad_right, 0, 0, 0, 0] pad_data = pad(packed_data, pad_before, pad_after, name="pad_data") - - Conv = te.compute((out_height, out_width, batch // wmma_m, out_channels // wmma_n,wmma_m, wmma_n), + Conv = te.compute((out_height, out_width, batch // wmma_m, out_channels // wmma_n, wmma_m, wmma_n), lambda h, w, n, o, nn, oo: te.sum( (pad_data[h * stride_h + kh, w * stride_w + kw, n, ic, nn, ii].astype('int32') * packed_kernel[kh, kw, o, ic, oo, ii].astype('int32')), @@ -306,7 +305,6 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv): s[AS].compute_at(s[ConvF], ko) else: s[AS].compute_at(s[ConvF], kh) - # s[AS].compute_at(s[ConvF], kh) h, w, n, i, nn, ii = AS.op.axis tx, xo = s[AS].split(n, nparts=block_row_warps) ty, yo = s[AS].split(xo, nparts=block_col_warps) diff --git a/topi/tests/python/test_topi_conv2d_hwnc_tensorcore.py b/topi/tests/python/test_topi_conv2d_hwnc_tensorcore.py index 659138a40b20..3d6848005574 100644 --- a/topi/tests/python/test_topi_conv2d_hwnc_tensorcore.py +++ b/topi/tests/python/test_topi_conv2d_hwnc_tensorcore.py @@ -41,7 +41,6 @@ def verify_conv2d_hwnc(batch, in_channel, in_size, num_filter, kernel, stride, batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) # choose dtype from int4, int8 assert dtype in ['int4', 'int8'] - out_dtype = 'int32' in_height = in_width = in_size @@ -59,7 +58,7 @@ def get_ref_data(): elif dtype == 'int8': a_np = np.random.randint(low=-128, high=127, size=a_shape).transpose((2, 0, 1, 3)).astype(dtype) w_np = np.random.randint(low=-128, high=127, size=w_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + dw_np = topi.testing.dilate_python(w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation)) c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) return a_np, w_np, c_np @@ -119,9 +118,8 @@ def check_device(device): def test_conv2d_hwnc_tensorcore(): """Test the conv2d with tensorcore for hwnc layout""" - verify_conv2d_hwnc(8, 64, 56, 64, 3, 1, 1) - verify_conv2d_hwnc(8, 32, 7, 8, 3, 1, 1) - verify_conv2d_hwnc(8, 64, 56, 64, 1, 1, 0) + verify_conv2d_hwnc(8, 64, 56, 64, 3, 1, 1, dtype='int8') + verify_conv2d_hwnc(8, 64, 56, 64, 1, 1, 0, dtype='int4') verify_conv2d_hwnc(8, 64, 56, 128, 3, 2, 1) verify_conv2d_hwnc(8, 64, 56, 64, 1, 2, 0) verify_conv2d_hwnc(8, 128, 28, 128, 3, 1, 1) From 211d5303403b84015f980828869a48d0c67c5187 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Fri, 14 Aug 2020 09:14:15 +0000 Subject: [PATCH 16/21] fix assert --- topi/python/topi/cuda/conv2d_alter_op.py | 4 ++-- topi/python/topi/cuda/conv2d_hwnc_tensorcore.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/topi/python/topi/cuda/conv2d_alter_op.py b/topi/python/topi/cuda/conv2d_alter_op.py index 8d0f72d1eb86..e4b64dcac593 100644 --- a/topi/python/topi/cuda/conv2d_alter_op.py +++ b/topi/python/topi/cuda/conv2d_alter_op.py @@ -170,10 +170,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): "group_conv2d_NCHWc_int8.cuda") dispatch_ctx.update(target, new_workload, cfg) return relay.nn.conv2d(*inputs, **new_attrs) - + if topi_tmpl == "conv2d_HWNCnc_tensorcore.cuda": assert data_layout == "HWNC" and kernel_layout == "HWOI" - + assert float(tvm.gpu(0).compute_version) >= 7.5 H, W, N, CI = get_const_tuple(data.shape) KH, KW, CO, _ = get_const_tuple(kernel.shape) diff --git a/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py b/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py index 568887795e84..64888feb75ce 100644 --- a/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py +++ b/topi/python/topi/cuda/conv2d_hwnc_tensorcore.py @@ -84,11 +84,9 @@ def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp if in_dtype in ['int4', 'uint4']: assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0) else: - assert (batch % 16 == 0 and in_channels % 16 == 0 and num_filter % 16 == 0) or \ - (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0) or \ - (batch % 32 == 0 and in_channels % 16 == 0 and num_filter % 8 == 0), \ + assert (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0), \ "The shape of (batch, in_channels, num_filter) "\ - "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for fp16 and int8, "\ + "must be multiple of (8, 16, 32) for int8, "\ "and (8, 32, 8) for int4" # compute the output shape From 95f000cfb740bea38917ee0c7775dd2b042ee7c1 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Sun, 16 Aug 2020 10:06:19 +0000 Subject: [PATCH 17/21] add asf header --- .../tvm/topi/cuda/conv2d_hwnc_tensorcore.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py index 46b44861ff1b..b74bc1415abe 100644 --- a/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py +++ b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py @@ -1,3 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-function-args +# pylint: disable=too-many-statements, unused-argument, too-many-arguments +"""Tensorcore template for cuda backend""" import numpy as np import tvm from tvm import te From b6faf5b4dc3fc6d0fea4295d6c4256117080d775 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Sun, 16 Aug 2020 13:18:24 +0000 Subject: [PATCH 18/21] CI --- src/target/source/codegen_c.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 4cd6f58ceff9..649eb86bc07e 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -619,8 +619,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) this->PrintType(l->dtype.element_of(), os); os << " *)" << this->GetVarID(l->buffer_var.get()) << " + " << "(";; this->PrintExpr(l->index, os); - if (l->dtype.bits() == 4 || - (l->dtype.bits() == 1 && l->dtype.is_int())) { + if (l->dtype.bits() == 4 || (l->dtype.bits() == 1 && l->dtype.is_int())) { os << " / " << (32 / l->dtype.bits()); } os << "))"; From 859ecb9965ff2cbed43173772a199b76eca05a11 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Sun, 16 Aug 2020 13:29:59 +0000 Subject: [PATCH 19/21] CI --- src/target/source/codegen_c.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 649eb86bc07e..2f19d6e126ad 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -617,7 +617,8 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) CHECK(op->args.size() == 1 && l); os << "(("; this->PrintType(l->dtype.element_of(), os); - os << " *)" << this->GetVarID(l->buffer_var.get()) << " + " << "(";; + os << " *)" << this->GetVarID(l->buffer_var.get()) << " + " + << "("; this->PrintExpr(l->index, os); if (l->dtype.bits() == 4 || (l->dtype.bits() == 1 && l->dtype.is_int())) { os << " / " << (32 / l->dtype.bits()); From 46244779f0d05b0f6196f2c5293468f3dcf51204 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Sun, 16 Aug 2020 14:09:29 +0000 Subject: [PATCH 20/21] CI --- python/tvm/relay/op/strategy/cuda.py | 3 +- python/tvm/topi/cuda/conv2d_alter_op.py | 4 +- .../tvm/topi/cuda/conv2d_hwnc_tensorcore.py | 106 +++++++++++------- 3 files changed, 68 insertions(+), 45 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 97f5bd2b857e..454bef038262 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -181,7 +181,8 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): out_channels = oc_chunk * oc_block_factor else: _, _, out_channels, _ = get_const_tuple(kernel.shape) - if topi.cuda.is_shape_tensorcore_direct_qualified(batch=N, in_channels=in_channels, num_filter=out_channels, in_dtype=data.dtype): + if topi.cuda.is_shape_tensorcore_direct_qualified( + batch=N, in_channels=in_channels, num_filter=out_channels, in_dtype=data.dtype): strategy.add_implementation( wrap_compute_conv2d(topi.cuda.conv2d_hwnc_tensorcore), wrap_topi_schedule(topi.cuda.schedule_conv2d_hwnc_tensorcore), diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index e4b64dcac593..f07ef984025f 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -187,12 +187,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): ic_block_factor = 32 oc_block_factor = 8 else: - new_attrs['kernel_layout']= 'HWOI32o16i' + new_attrs['kernel_layout'] = 'HWOI32o16i' ic_block_factor = 16 oc_block_factor = 32 new_kernel = te.placeholder((KH, KW, CO // oc_block_factor, CI // ic_block_factor, - oc_block_factor, ic_block_factor), dtype=kernel.dtype) + oc_block_factor, ic_block_factor), dtype=kernel.dtype) new_workload = autotvm.task.args_to_workload( [data, new_kernel, strides, padding, dilation, out_dtype], diff --git a/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py index b74bc1415abe..4aedee385cec 100644 --- a/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py +++ b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py @@ -17,15 +17,18 @@ # pylint: disable=invalid-name, too-many-locals, too-many-function-args # pylint: disable=too-many-statements, unused-argument, too-many-arguments """Tensorcore template for cuda backend""" -import numpy as np import tvm from tvm import te from tvm import autotvm +from tvm.topi.cuda.injective import schedule_injective_from_existing from ..util import get_const_tuple, traverse_inline, simplify, tag from ..nn.pad import pad from ..nn.util import get_pad_tuple -from tvm.topi.cuda.injective import schedule_injective_from_existing -from .tensor_intrin import intrin_wmma_load_matrix_A, intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm +from .tensor_intrin import intrin_wmma_load_matrix_A +from .tensor_intrin import intrin_wmma_load_matrix_W +from .tensor_intrin import intrin_wmma_store_matrix +from .tensor_intrin import intrin_wmma_gemm + def unpack_HWNCnc_to_hwnc(packed_out, out_dtype): """Unpack conv2d_hwnc output from layout hwncnc to hwnc @@ -52,20 +55,24 @@ def unpack_HWNCnc_to_hwnc(packed_out, out_dtype): unpacked_out = \ te.compute(oshape, lambda h, w, n, o: - packed_out[h, w, idxdiv(n, wmma_m), idxdiv(o, wmma_n), idxmod(n, wmma_m), idxmod(o, wmma_n)] + packed_out[h, w, idxdiv(n, wmma_m), idxdiv(o, wmma_n), + idxmod(n, wmma_m), idxmod(o, wmma_n)] .astype(out_dtype), name='output_unpack', - tag=tag.INJECTIVE+",unpack_hwncc") + tag=tag.INJECTIVE + ",unpack_hwncc") return unpacked_out + def conv2d_hwnc_tensorcore(data, kernel, strides, padding, dilation, in_dtype, out_dtype='int32'): """Compute conv2d internally using conv2d_nchwc layout for int8 dtype""" assert data.dtype in ('int4', 'uint4', 'int8', 'uint8') assert kernel.dtype in ('int4', 'uint4', 'int8', 'uint8') # assert data.dtype == kernel.dtype - packed_out = hwnc_tensorcore_cuda(data, kernel, strides, padding, dilation, out_dtype) + packed_out = hwnc_tensorcore_cuda( + data, kernel, strides, padding, dilation, out_dtype) return unpack_HWNCnc_to_hwnc(packed_out, out_dtype) + @autotvm.register_topi_compute("conv2d_HWNCnc_tensorcore.cuda") def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype='int32'): """Compute declaration for tensorcore""" @@ -95,18 +102,20 @@ def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp pre_computed = len(Filter.shape) == 6 in_height, in_width, batch, in_channels = get_const_tuple(Input.shape) if pre_computed: - kernel_h, kernel_w, oc_chunk, ic_chunk, oc_block_factor, ic_block_factor = get_const_tuple(Filter.shape) + kernel_h, kernel_w, oc_chunk, _, oc_block_factor, _\ + = get_const_tuple(Filter.shape) num_filter = oc_block_factor * oc_chunk else: kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape) if in_dtype in ['int4', 'uint4']: - assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0) + assert (batch % 8 == 0 and in_channels % + 32 == 0 and num_filter % 8 == 0) else: assert (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0), \ - "The shape of (batch, in_channels, num_filter) "\ - "must be multiple of (8, 16, 32) for int8, "\ - "and (8, 32, 8) for int4" + "The shape of (batch, in_channels, num_filter) "\ + "must be multiple of (8, 16, 32) for int8, "\ + "and (8, 32, 8) for int4" # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 @@ -116,10 +125,13 @@ def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp padding, (dilated_kernel_h, dilated_kernel_w)) out_channels = num_filter - out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) - out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + out_height = simplify( + (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + + pad_left + pad_right) // stride_w + 1) - cfg.add_flop(2 * batch * out_height * out_width * out_channels * in_channels * kernel_h * kernel_w) + cfg.add_flop(2 * batch * out_height * out_width * + out_channels * in_channels * kernel_h * kernel_w) # Input feature map: (H, W, N, IC, n, ic) data_shape = (in_height, @@ -146,29 +158,33 @@ def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp if pre_computed: packed_kernel = Filter else: - packed_kernel = te.compute(kernel_shape, - lambda kh, kw, o, i, oo, ii: Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii], - name="packed_kernel" - ) + packed_kernel = te.compute(kernel_shape, lambda kh, kw, o, i, oo, ii: + Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii], + name="packed_kernel" + ) packed_data = te.compute(data_shape, - lambda h, w, n, i, nn, ii: Input[h, w, n * wmma_m + nn, i * wmma_k + ii] - ) + lambda h, w, n, i, nn, ii: Input[h, + w, n * wmma_m + nn, i * wmma_k + ii] + ) pad_before = [pad_top, pad_left, 0, 0, 0, 0] pad_after = [pad_down, pad_right, 0, 0, 0, 0] pad_data = pad(packed_data, pad_before, pad_after, name="pad_data") - Conv = te.compute((out_height, out_width, batch // wmma_m, out_channels // wmma_n, wmma_m, wmma_n), - lambda h, w, n, o, nn, oo: te.sum( - (pad_data[h * stride_h + kh, w * stride_w + kw, n, ic, nn, ii].astype('int32') * - packed_kernel[kh, kw, o, ic, oo, ii].astype('int32')), - axis=[ic, kh, kw, ii]), - name="Conv", tag="conv2d_HWNCnc_tensorcore") + Conv = te.compute((out_height, out_width, batch // wmma_m, + out_channels // wmma_n, wmma_m, wmma_n), + lambda h, w, n, o, nn, oo: te.sum( + (pad_data[h * stride_h + kh, w * stride_w + kw, + n, ic, nn, ii].astype('int32') * + packed_kernel[kh, kw, o, ic, oo, ii].astype('int32')), + axis=[ic, kh, kw, ii]), + name="Conv", tag="conv2d_HWNCnc_tensorcore") return Conv def schedule_hwnc_tensorcore_cuda(cfg, s, Conv): + """Schedule tensorcore template""" packed_data, packed_kernel = s[Conv].op.input_tensors ic, kh, kw, ii = s[Conv].op.reduce_axis pad_data = s[packed_data].op.input_tensors[0] @@ -200,7 +216,8 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv): if isinstance(packed_kernel.op, te.tensor.ComputeOp) and packed_kernel.name == "packed_kernel": if autotvm.GLOBAL_SCOPE.in_tuning: - s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region") + s[packed_kernel].pragma( + s[packed_kernel].op.axis[0], "debug_skip_region") else: with tvm.target.create('cuda'): schedule_injective_from_existing(s, packed_kernel) @@ -259,7 +276,7 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv): # Schedule for output if len(s[output].op.axis) == 4: - hc, wc, nc, oc, = output.op.axis + hc, wc, nc, oc, = output.op.axis nc, nnc = s[output].split(nc, factor=wmma_m) oc, ooc = s[output].split(oc, factor=wmma_n) else: @@ -268,12 +285,14 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv): kernel_scope, hc = s[output].split(hc, nparts=1) block_k = s[output].fuse(hc, wc) - block_k, split_block_k = s[output].split(block_k, factor=split_block_k_nums) + block_k, split_block_k = s[output].split( + block_k, factor=split_block_k_nums) nc, nci = s[output].split(nc, factor=warp_row_tiles) block_i, nc = s[output].split(nc, factor=block_row_warps) oc, oci = s[output].split(oc, factor=warp_col_tiles) block_j, oc = s[output].split(oc, factor=block_col_warps) - s[output].reorder(block_k, split_block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc) + s[output].reorder(block_k, split_block_k, block_i, + block_j, nc, oc, nci, oci, nnc, ooc) t = s[output].fuse(nnc, ooc) _, tx = s[output].split(t, factor=warp_size) s[output].bind(block_k, block_z) @@ -296,7 +315,7 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv): # Schedule local computation s[ConvF].compute_at(s[OL], oc) - h, w, n, o, nnf, oof = ConvF.op.axis + _, _, n, o, nnf, oof = ConvF.op.axis ko, ki = s[ConvF].split(ic, factor=chunk) s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii) @@ -322,9 +341,9 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv): s[AS].compute_at(s[ConvF], ko) else: s[AS].compute_at(s[ConvF], kh) - h, w, n, i, nn, ii = AS.op.axis + _, _, n, _, nn, ii = AS.op.axis tx, xo = s[AS].split(n, nparts=block_row_warps) - ty, yo = s[AS].split(xo, nparts=block_col_warps) + ty, _ = s[AS].split(xo, nparts=block_col_warps) t = s[AS].fuse(nn, ii) to, ti = s[AS].split(t, nparts=warp_size) ti, _t = s[AS].split(ti, factor=vector_as) @@ -345,7 +364,7 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv): s[WS].compute_at(s[ConvF], kw) kh, kw, ic, o, ii, oo = WS.op.axis tx, xo = s[WS].split(o, nparts=block_row_warps) - ty, yo = s[WS].split(xo, nparts=block_col_warps) + ty, _ = s[WS].split(xo, nparts=block_col_warps) t = s[WS].fuse(ii, oo) to, ti = s[WS].split(t, nparts=warp_size) ti, _t = s[WS].split(ti, factor=vector_ws) @@ -381,8 +400,9 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv): WL_gemm = te.placeholder(WL_shape, name='B', dtype=kernel_dtype) k_gemm = te.reduce_axis((0, wmma_k), name="k") CL_compute = te.compute(CL_shape, lambda ii, jj: - te.sum((AL_gemm[ii, k_gemm].astype('int32')* WL_gemm[jj, k_gemm].astype('int32')), axis=k_gemm), - name='C') + te.sum((AL_gemm[ii, k_gemm].astype( + 'int32') * WL_gemm[jj, k_gemm].astype('int32')), axis=k_gemm), + name='C') AL_strides = [wmma_k, 1] AS_strides = [wmma_k, 1] @@ -391,26 +411,28 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv): CL_strides = [wmma_n, 1] CS_strides = [wmma_n, 1] + s[AF].tensorize(AF.op.axis[-2], + intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape, + "row_major", AS_shape, AL_shape, data_dtype)) - s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape, - "row_major", AS_shape, AL_shape, data_dtype)) - - s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, - "col_major", WS_shape, WL_shape, kernel_dtype)) + s[WF].tensorize(WF.op.axis[-2], + intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, + "col_major", WS_shape, WL_shape, kernel_dtype)) s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, shape, out_dtype, CL_shape, CS_shape)) - s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, WL_strides, CL_strides, shape)) return s + @autotvm.register_topi_schedule("conv2d_HWNCnc_tensorcore.cuda") def schedule_conv2d_hwnc_tensorcore(cfg, outs): """TOPI schedule callback""" s = te.create_schedule([x.op for x in outs]) + def _callback(op): if 'conv2d_HWNCnc_tensorcore' in op.tag: schedule_hwnc_tensorcore_cuda(cfg, s, op.output(0)) From a048fe21f51980c45bbbd385477f122bfa381075 Mon Sep 17 00:00:00 2001 From: GaryYuyjl <1035194528@qq.com> Date: Mon, 17 Aug 2020 01:16:10 +0000 Subject: [PATCH 21/21] fix bug fix bug --- python/tvm/relay/op/strategy/cuda.py | 3 +++ python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py | 5 ++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 454bef038262..4b50937fc838 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -188,6 +188,9 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_conv2d_hwnc_tensorcore), name="conv2d_hwnc_tensorcore_direct.cuda", plevel=20) + else: + raise RuntimeError("Unsupported shape for conv2d HWNC.\ + Need to satisfy tensor core schedule.") elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: assert kernel_layout == "OIHW4o4i" strategy.add_implementation( diff --git a/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py index 4aedee385cec..592613ffcf92 100644 --- a/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py +++ b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py @@ -64,10 +64,9 @@ def unpack_HWNCnc_to_hwnc(packed_out, out_dtype): def conv2d_hwnc_tensorcore(data, kernel, strides, padding, dilation, in_dtype, out_dtype='int32'): - """Compute conv2d internally using conv2d_nchwc layout for int8 dtype""" + """"Compute conv2d with tensorcore for HWNC layout with int8/int4""" assert data.dtype in ('int4', 'uint4', 'int8', 'uint8') assert kernel.dtype in ('int4', 'uint4', 'int8', 'uint8') - # assert data.dtype == kernel.dtype packed_out = hwnc_tensorcore_cuda( data, kernel, strides, padding, dilation, out_dtype) return unpack_HWNCnc_to_hwnc(packed_out, out_dtype) @@ -141,7 +140,7 @@ def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp wmma_m, wmma_k) - # Kernel: (H, W, OC, IC, ic, oc) + # Kernel: (H, W, OC, IC, oc, ic) kernel_shape = (kernel_h, kernel_w, out_channels // wmma_n,