From 360de2233bd978e65bdacf6bdd95b40266aed3ec Mon Sep 17 00:00:00 2001
From: Mercy
Date: Tue, 13 Feb 2018 21:26:56 +0800
Subject: [PATCH 1/4] add winograd for mali

---
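Note: this patch implements the Winograd F(2x2, 3x3) algorithm. Each 4x4 input tile d and each 3x3 kernel g are transformed once, multiplied elementwise, and transformed back:

    Y = A^T [ (G g G^T) * (B^T d B) ] A

which produces a 2x2 output block with 16 multiplies instead of the 36 a direct computation needs (per channel and tile, ignoring the transform cost). The B_data/G_data/A_data constants introduced below are the standard F(2x2, 3x3) transform matrices; they can be sanity-checked against a direct convolution with plain NumPy (a standalone sketch, not part of the diff):

    import numpy as np

    B = np.array([[1, 0, 0, 0], [0, 1, -1, 1], [-1, 1, 1, 0], [0, 0, 0, -1]], np.float32)
    G = np.array([[1, 0, 0], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0, 0, 1]], np.float32)
    A = np.array([[1, 0], [1, 1], [1, -1], [0, -1]], np.float32)

    np.random.seed(0)
    d = np.random.randn(4, 4).astype(np.float32)   # one 4x4 input tile
    g = np.random.randn(3, 3).astype(np.float32)   # one 3x3 kernel

    U = G.dot(g).dot(G.T)        # kernel transform, G g G^T
    V = B.T.dot(d).dot(B)        # input transform, B^T d B
    Y = A.T.dot(U * V).dot(A)    # elementwise product, then inverse transform

    # reference: the four valid positions of a stride-1 correlation
    ref = np.array([[(d[i:i+3, j:j+3] * g).sum() for j in range(2)]
                    for i in range(2)])
    assert np.allclose(Y, ref, atol=1e-4)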
 topi/python/topi/mali/conv2d.py | 222 ++++++++++++++++++++++++++++++--
 1 file changed, 209 insertions(+), 13 deletions(-)

diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py
index ff67e0503f4f..4be4161540f8 100644
--- a/topi/python/topi/mali/conv2d.py
+++ b/topi/python/topi/mali/conv2d.py
@@ -2,6 +2,8 @@
 """conv2d schedule on ARM Mali GPU"""
 from __future__ import absolute_import as _abs
+
+import numpy as np
 import tvm

 from .. import generic
@@ -63,7 +65,23 @@ def transpose(s, tensor, readers):
     s[tmp].compute_inline()
     return s.cache_write(tmp, "global"), tmp

-@conv2d.register("mali")
+def const_array(data, name):
+    """convert a const array to tvm tensor"""
+    row, col = data.shape
+    dtype = str(data.dtype)
+
+    def select_array(i, j):
+        now = tvm.const(0.0, dtype)
+        for ii in range(row):
+            for jj in range(col):
+                now = tvm.select(tvm.all(i % row == ii, j % col == jj),
+                                 tvm.const(data[ii][jj], dtype),
+                                 now)
+        return now
+    return tvm.compute(data.shape, select_array, name=name)
+
+
+@conv2d.register(["mali"])
 def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
     """Conv2D operator for ARM Mali GPU backend.
@@ -94,10 +112,23 @@ def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
     assert data.dtype == kernel.dtype, "Do not support inputs with different data types now."
     out_dtype = data.dtype

-    if util.get_const_int(kernel.shape[2]) == 1:
+    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    kernel_shape = util.get_const_tuple(kernel.shape)
+    data_shape = util.get_const_tuple(data.shape)
+    if isinstance(stride, (tuple, list)):
+        HSTR, WSTR = stride
+    else:
+        HSTR, WSTR = stride, stride
+
+    gemm_factor = 4
+
+    if (kernel_shape[2:4] == (3, 3) and (HPAD, WPAD) == (1, 1) and kernel_shape[0] >= 64
+            and data_shape[2] * data_shape[3] // 4 % gemm_factor == 0 and (HSTR, WSTR) == (1, 1)):
+        return _decl_winograd(data, kernel, stride, padding, layout, out_dtype)
+    elif kernel_shape[2:4] == (1, 1):
         return _decl_im2col(data, kernel, stride, padding, layout, out_dtype)
     else:
-        return _decl_direct(data, kernel, stride, padding, layout, out_dtype)
+        return _decl_spatialpack(data, kernel, stride, padding, layout, out_dtype)

 @generic.schedule_conv2d_nchw.register(["mali"])
 def schedule_conv2d_nchw(outs):
@@ -129,14 +160,17 @@ def traverse(op):
         if 'im2col_conv_output' in op.tag:
             _schedule_im2col_conv2d(s, op)

-        if 'direct_conv_output' in op.tag:
-            _schedule_direct_conv2d(s, op)
+        if 'spatialpack_conv_output' in op.tag:
+            _schedule_spatialpack_conv2d(s, op)
+
+        if 'winograd_conv_output' in op.tag:
+            _schedule_winograd(s, op)

     traverse(outs[0].op)
     return s

-def _decl_direct(data, kernel, stride, padding, layout, out_dtype):
-    """declare the direct method (spatial packing) for conv2d"""
+def _decl_spatialpack(data, kernel, stride, padding, layout, out_dtype):
+    """declare the spatialpack method (spatial packing) for conv2d"""
     _, CI, IH, IW = [util.get_const_int(x) for x in data.shape]
     CO, _, KH, KW = [util.get_const_int(x) for x in kernel.shape]
     HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
@@ -207,12 +241,12 @@ def _decl_direct(data, kernel, stride, padding, layout, out_dtype):

     output = tvm.compute(oshape, lambda n, co, h, w:
                          conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
-                         name='output_unpack', tag='direct_conv_output')
+                         name='output_unpack', tag='spatialpack_conv_output')

     return output

-def _schedule_direct_conv2d(s, op):
-    """schedule the direct method (spatial packing) for conv2d"""
+def _schedule_spatialpack_conv2d(s, op):
+    """schedule the spatialpack method (spatial packing) for conv2d"""
     # get ops and tensors
     output = op.output(0)
     output_height = util.get_const_int(output.shape[2])
@@ -294,8 +328,6 @@ def _schedule_direct_conv2d(s, op):
     _, co, oh, ow = s[output].op.axis
     tile_and_bind3d(s, output, co, oh, ow, num_thread, 1, last)

-    #print(tvm.lower(s, [data, kernel, output], simple_mode=True))
-
 def _decl_im2col(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
     """declare the Im2Col method for conv2d"""
     _, CI, IH, IW = [x.value for x in data.shape]
@@ -476,4 +508,168 @@ def _schedule_im2col_conv2d(s, op):
         s[output].vectorize(vw)

     fuse_and_bind(s, output, [n, co, h, w])

-    #print(tvm.lower(s, [data, kernel], simple_mode=True))
+def _decl_winograd(data, kernel, stride, padding, layout, out_dtype):
+    """schedule the winograd fast convolution F(2x2, 3x3) for conv2d"""
+    N, CI, H, W = [util.get_const_int(x) for x in data.shape]
+    CO, CI, KH, KW = [util.get_const_int(x) for x in kernel.shape]
+    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    if isinstance(stride, (tuple, list)):
+        HSTR, WSTR = stride
+    else:
+        HSTR, WSTR = stride, stride
+
+    assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 and KH == 3 and KW == 3
+    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
+
+    B_data = np.array([
+        [1, 0, 0, 0],
+        [0, 1, -1, 1],
+        [-1, 1, 1, 0],
+        [0, 0, 0, -1]
+    ], out_dtype)
+
+    G_data = np.array([
+        [1, 0, 0],
+        [1.0/2, 1.0/2, 1.0/2],
+        [1.0/2, -1.0/2, 1.0/2],
+        [0, 0, 1],
+    ], out_dtype)
+
+    A_data = np.array([
+        [1, 0],
+        [1, 1],
+        [1, -1],
+        [0, -1],
+    ], out_dtype)
+
+    m = 2
+    r = 3
+    alpha = m + r - 1
+    K = CO
+    C = CI
+
+    nH, nW = (H + m-1) // m, (W + m-1) // m
+    P = N * nH * nW
+
+    bna, bnb = 4, 4
+    if data.dtype == 'float16' and P % (bnb * 2) == 0:
+        bnb *= 2
+    assert K % bna == 0 and P % bnb == 0
+
+    # pack input tile
+    input_tile = tvm.compute((C, P // bnb, alpha, alpha, bnb),
+                             lambda c, b, eps, nu, bb:
+                             data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps][(b*bnb+bb) % nW * m + nu],
+                             name='d')
+
+    # transform kernel
+    G = const_array(G_data, 'G')
+    r_kh = tvm.reduce_axis((0, KH), 'r_kh')
+    r_kw = tvm.reduce_axis((0, KW), 'r_kw')
+    U = tvm.compute((alpha, alpha, K // bna, C, bna), lambda eps, nu, k, c, kk:
+                    tvm.sum(kernel[k * bna + kk][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
+                    name='U')
+
+    # transform image
+    B = const_array(B_data, 'B')
+    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
+    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
+    V = tvm.compute((alpha, alpha, P // bnb, C, bnb), lambda eps, nu, b, c, bb:
+                    tvm.sum(input_tile[c][b][r_eps][r_nu][bb] * B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]),
+                    name='V')
+
+    # batch gemm
+    c = tvm.reduce_axis((0, C), name='c')
+    M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b:
+                    tvm.sum(U[eps][nu][k // bna][c][k % bna] *
+                            V[eps][nu][b // bnb][c][b % bnb], axis=c), name='M')

+    # inverse transform
+    A = const_array(A_data, 'A')
+    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
+    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
+    Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw:
+                    tvm.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw], axis=[r_eps, r_nu]),
+                    name='Y')
+
+    # unpack output
+    output = tvm.compute((N, K, H, W), lambda n, k, h, w:
+                         Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
+                         name='output', tag='winograd_conv_output')
+
+    return output
+
+def _schedule_winograd(s, op):
+    """schedule the winograd fast convolution F(2x2, 3x3) for conv2d"""
+
+    # get ops and tensors
+    output = op.output(0)
+
+    Y = op.input_tensors[0]
+    M, A = s[Y].op.input_tensors
+    U, V = s[M].op.input_tensors
+    kernel, G = s[U].op.input_tensors
+    d, B = s[V].op.input_tensors
+    data_pad = s[d].op.input_tensors[0]
+    data = s[data_pad].op.input_tensors[0]
+
+    # padding
+    s[data_pad].compute_inline()
+
+    # pack input tiles
+    c, b, eps, nu, bb = s[d].op.axis
+    s[d].reorder(eps, nu, bb)
+    aha = s[d].fuse(eps, nu)
+    s[d].unroll(bb)
+    tile_and_bind3d(s, d, c, b, aha, 4, 1, 1)
+
+    # transform kernel
+    s[G].compute_inline()
+    eps, nu, k, c, kk, = s[U].op.axis
+    r_kh, r_kw = s[U].op.reduce_axis
+    s[U].reorder(k, c, kk, eps, nu, r_kh, r_kw)
+    [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]]
+    s[U].vectorize(kk)
+    tile_and_bind(s, U, k, c, 1, 256)
+
+    # transform image
+    s[B].compute_inline()
+    eps, nu, b, c, bb = s[V].op.axis
+    r_eps, r_nu = s[V].op.reduce_axis
+    s[V].reorder(b, c, bb, eps, nu, r_nu, r_eps)
+    [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]]
+    s[V].vectorize(bb)
+    tile_and_bind(s, V, b, c, 2, 1)
+
+    # batch gemm
+    bna, bnb = 4, 4
+    if data.dtype == 'float16' and util.get_const_int(M.shape[3]) % (bnb * 2) == 0:
+        bnb *= 2
+
+    eps, nu, k, b = s[M].op.axis
+    c = s[M].op.reduce_axis[0]
+    yo, xo, yi, xi = s[M].tile(k, b, bna, bnb)
+    s[M].reorder(c, yi, xi)
+    c, c_unroll = s[M].split(c, 2)
+    s[M].unroll(c_unroll)
+    s[M].unroll(yi)
+    s[M].vectorize(xi)
+    z = s[M].fuse(eps, nu)
+    tile_and_bind3d(s, M, z, yo, xo, 1, 8, 1)
+
+    # inverse transform
+    s[A].compute_inline()
+    k, b, vh, vw = s[Y].op.axis
+    r_eps, r_nu = s[Y].op.reduce_axis
+    [s[Y].unroll(x) for x in [vh, vw, r_eps, r_nu]]
+    tile_and_bind(s, Y, k, b, 4, 1)
+
+    # schedule output
+    if output.op in s.outputs:  # no bias
+        output = output
+    else:                       # has bias
+        s[output].compute_inline()
+        output = s.outputs[0]
+
+    _, k, h, w = s[output].op.axis
+    tile_and_bind3d(s, output, k, h, w, 1, 2, 2)
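The schedule above leans on the tile_and_bind / tile_and_bind3d helpers defined earlier in this file, which the diff does not show. A rough sketch of tile_and_bind3d, reconstructed here for illustration only (the in-tree helper may differ in details such as axis reordering), splits three axes by the given factors and binds the pieces to the OpenCL work grid:

    import tvm

    def tile_and_bind3d_sketch(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
        # split each axis into an outer (workgroup) and an inner (thread) part
        y_factor = y_factor or z_factor
        x_factor = x_factor or y_factor
        zo, zi = s[tensor].split(z, z_factor)
        yo, yi = s[tensor].split(y, y_factor)
        xo, xi = s[tensor].split(x, x_factor)
        # outer parts become the grid, inner parts the threads of a workgroup
        s[tensor].bind(zo, tvm.thread_axis("blockIdx.z"))
        s[tensor].bind(zi, tvm.thread_axis("threadIdx.z"))
        s[tensor].bind(yo, tvm.thread_axis("blockIdx.y"))
        s[tensor].bind(yi, tvm.thread_axis("threadIdx.y"))
        s[tensor].bind(xo, tvm.thread_axis("blockIdx.x"))
        s[tensor].bind(xi, tvm.thread_axis("threadIdx.x"))

Under that reading, tile_and_bind3d(s, M, z, yo, xo, 1, 8, 1) in the batch-gemm stage launches workgroups of 1x8x1 threads over the fused (eps, nu) axis and the outer gemm tiles.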
From 8bac8ee83a2cbfc196fce31647c7a4e40f1505bb Mon Sep 17 00:00:00 2001
From: Mercy
Date: Tue, 13 Feb 2018 22:11:06 +0800
Subject: [PATCH 2/4] fix lint

---
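Pure style pass, no functional change: continuation lines are re-indented for pylint, and the bare list comprehensions used for their side effect (unrolling a set of axes) are assigned to _ to silence pylint's expression-not-assigned warning. An equivalent and arguably clearer spelling, not used here to keep the diff minimal, is a plain loop:

    for axis in (eps, nu, r_kh, r_kw):
        s[U].unroll(axis)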
 topi/python/topi/mali/conv2d.py | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py
index 4be4161540f8..8f59040b23a6 100644
--- a/topi/python/topi/mali/conv2d.py
+++ b/topi/python/topi/mali/conv2d.py
@@ -122,8 +122,8 @@ def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32

     gemm_factor = 4

-    if (kernel_shape[2:4] == (3, 3) and (HPAD, WPAD) == (1, 1) and kernel_shape[0] >= 64
-            and data_shape[2] * data_shape[3] // 4 % gemm_factor == 0 and (HSTR, WSTR) == (1, 1)):
+    if (kernel_shape[2:4] == (3, 3) and (HPAD, WPAD) == (1, 1) and kernel_shape[0] >= 64 and
+            data_shape[2] * data_shape[3] // 4 % gemm_factor == 0 and (HSTR, WSTR) == (1, 1)):
         return _decl_winograd(data, kernel, stride, padding, layout, out_dtype)
     elif kernel_shape[2:4] == (1, 1):
         return _decl_im2col(data, kernel, stride, padding, layout, out_dtype)
@@ -559,7 +559,8 @@ def _decl_winograd(data, kernel, stride, padding, layout, out_dtype):
     # pack input tile
     input_tile = tvm.compute((C, P // bnb, alpha, alpha, bnb),
                              lambda c, b, eps, nu, bb:
-                             data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps][(b*bnb+bb) % nW * m + nu],
+                             data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps]
+                             [(b*bnb+bb) % nW * m + nu],
                              name='d')

     # transform kernel
@@ -567,34 +568,34 @@ def _decl_winograd(data, kernel, stride, padding, layout, out_dtype):
     G = const_array(G_data, 'G')
     r_kh = tvm.reduce_axis((0, KH), 'r_kh')
     r_kw = tvm.reduce_axis((0, KW), 'r_kw')
     U = tvm.compute((alpha, alpha, K // bna, C, bna), lambda eps, nu, k, c, kk:
-                    tvm.sum(kernel[k * bna + kk][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
-                    name='U')
+                    tvm.sum(kernel[k * bna + kk][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw],
+                            axis=[r_kh, r_kw]), name='U')

     # transform image
     B = const_array(B_data, 'B')
     r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
     r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
     V = tvm.compute((alpha, alpha, P // bnb, C, bnb), lambda eps, nu, b, c, bb:
-                    tvm.sum(input_tile[c][b][r_eps][r_nu][bb] * B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]),
-                    name='V')
+                    tvm.sum(input_tile[c][b][r_eps][r_nu][bb] * B[r_eps][eps] * B[r_nu][nu],
+                            axis=[r_eps, r_nu]), name='V')

     # batch gemm
     c = tvm.reduce_axis((0, C), name='c')
     M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b:
-                tvm.sum(U[eps][nu][k // bna][c][k % bna] *
-                        V[eps][nu][b // bnb][c][b % bnb], axis=c), name='M')
+                    tvm.sum(U[eps][nu][k // bna][c][k % bna] *
+                            V[eps][nu][b // bnb][c][b % bnb], axis=c), name='M')

     # inverse transform
     A = const_array(A_data, 'A')
     r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
     r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
     Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw:
-                    tvm.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw], axis=[r_eps, r_nu]),
-                    name='Y')
+                    tvm.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw],
+                            axis=[r_eps, r_nu]), name='Y')

     # unpack output
     output = tvm.compute((N, K, H, W), lambda n, k, h, w:
-                         Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m], 
+                         Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
                          name='output', tag='winograd_conv_output')

     return output
@@ -628,7 +629,7 @@ def _schedule_winograd(s, op):
     eps, nu, k, c, kk, = s[U].op.axis
     r_kh, r_kw = s[U].op.reduce_axis
     s[U].reorder(k, c, kk, eps, nu, r_kh, r_kw)
-    [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]]
+    _ = [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]]
     s[U].vectorize(kk)
     tile_and_bind(s, U, k, c, 1, 256)
@@ -637,7 +638,7 @@ def _schedule_winograd(s, op):
     eps, nu, b, c, bb = s[V].op.axis
     r_eps, r_nu = s[V].op.reduce_axis
     s[V].reorder(b, c, bb, eps, nu, r_nu, r_eps)
-    [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]]
+    _ = [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]]
     s[V].vectorize(bb)
     tile_and_bind(s, V, b, c, 2, 1)
@@ -661,7 +662,7 @@ def _schedule_winograd(s, op):
     s[A].compute_inline()
     k, b, vh, vw = s[Y].op.axis
     r_eps, r_nu = s[Y].op.reduce_axis
-    [s[Y].unroll(x) for x in [vh, vw, r_eps, r_nu]]
+    _ = [s[Y].unroll(x) for x in [vh, vw, r_eps, r_nu]]
     tile_and_bind(s, Y, k, b, 4, 1)

     # schedule output

From 68a8ac3e945a730797e07bdea62822a7992d31ac Mon Sep 17 00:00:00 2001
From: Mercy
Date: Tue, 13 Feb 2018 23:01:42 +0800
Subject: [PATCH 3/4] add padding

---
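Previously the winograd path was only taken when H*W happened to satisfy the gemm_factor divisibility check, and the assert on P % bnb could still fire for other shapes. This patch drops that restriction: the number of tiles P is rounded up to P_round, a multiple of the gemm tile width bnb, and the tail is zero-filled. With illustrative numbers (not from the diff):

    N, H, W, m = 1, 28, 28, 2
    nH, nW = (H + m - 1) // m, (W + m - 1) // m    # 14, 14
    P = N * nH * nW                                # 196
    bnb = 8                                        # the float16 case
    P_round = (P + bnb - 1) // bnb * bnb           # 200 -> 4 zero tiles of padding

The padding tiles are masked off with tvm.select during packing and never referenced when the output is unpacked.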
 topi/python/topi/mali/conv2d.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py
index 8f59040b23a6..fdc181cd64ae 100644
--- a/topi/python/topi/mali/conv2d.py
+++ b/topi/python/topi/mali/conv2d.py
@@ -114,16 +114,13 @@ def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
     out_dtype = data.dtype
     HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
     kernel_shape = util.get_const_tuple(kernel.shape)
-    data_shape = util.get_const_tuple(data.shape)
     if isinstance(stride, (tuple, list)):
         HSTR, WSTR = stride
     else:
         HSTR, WSTR = stride, stride

-    gemm_factor = 4
-
     if (kernel_shape[2:4] == (3, 3) and (HPAD, WPAD) == (1, 1) and kernel_shape[0] >= 64 and
-            data_shape[2] * data_shape[3] // 4 % gemm_factor == 0 and (HSTR, WSTR) == (1, 1)):
+            (HSTR, WSTR) == (1, 1)):
         return _decl_winograd(data, kernel, stride, padding, layout, out_dtype)
     elif kernel_shape[2:4] == (1, 1):
         return _decl_im2col(data, kernel, stride, padding, layout, out_dtype)
@@ -552,15 +549,17 @@ def _decl_winograd(data, kernel, stride, padding, layout, out_dtype):
     P = N * nH * nW

     bna, bnb = 4, 4
-    if data.dtype == 'float16' and P % (bnb * 2) == 0:
+    if data.dtype == 'float16':
         bnb *= 2
-    assert K % bna == 0 and P % bnb == 0
+    P_round = (P + bnb - 1) // bnb * bnb
+    assert K % bna == 0 and P_round % bnb == 0

     # pack input tile
-    input_tile = tvm.compute((C, P // bnb, alpha, alpha, bnb),
+    input_tile = tvm.compute((C, P_round // bnb, alpha, alpha, bnb),
                              lambda c, b, eps, nu, bb:
-                             data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps]
-                             [(b*bnb+bb) % nW * m + nu],
+                             tvm.select(b * bnb + bb < P,\
+                             data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps]\
+                             [(b*bnb+bb) % nW * m + nu], tvm.const(0, data_pad.dtype)),
                              name='d')
@@ -575,13 +574,13 @@ def _decl_winograd(data, kernel, stride, padding, layout, out_dtype):
     B = const_array(B_data, 'B')
     r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
     r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
-    V = tvm.compute((alpha, alpha, P // bnb, C, bnb), lambda eps, nu, b, c, bb:
+    V = tvm.compute((alpha, alpha, P_round // bnb, C, bnb), lambda eps, nu, b, c, bb:
                     tvm.sum(input_tile[c][b][r_eps][r_nu][bb] * B[r_eps][eps] * B[r_nu][nu],
                             axis=[r_eps, r_nu]), name='V')

     # batch gemm
     c = tvm.reduce_axis((0, C), name='c')
-    M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b:
+    M = tvm.compute((alpha, alpha, K, P_round), lambda eps, nu, k, b:
                     tvm.sum(U[eps][nu][k // bna][c][k % bna] *
                             V[eps][nu][b // bnb][c][b % bnb], axis=c), name='M')
@@ -595,7 +594,10 @@ def _decl_winograd(data, kernel, stride, padding, layout, out_dtype):

     # unpack output
     output = tvm.compute((N, K, H, W), lambda n, k, h, w:
-                         Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
+                         Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m]
+                         # the following term is used to make the padding effective,
+                         # otherwise the padding will be eliminated by bound inference
+                         + tvm.const(0, out_dtype) * M[alpha-1][alpha-1][K-1][P_round-1],
                          name='output', tag='winograd_conv_output')

     return output
@@ -644,7 +646,7 @@ def _schedule_winograd(s, op):

     # batch gemm
     bna, bnb = 4, 4
-    if data.dtype == 'float16' and util.get_const_int(M.shape[3]) % (bnb * 2) == 0:
+    if data.dtype == 'float16':
         bnb *= 2

     eps, nu, k, b = s[M].op.axis
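A note on the odd-looking + tvm.const(0, out_dtype) * M[alpha-1][alpha-1][K-1][P_round-1] term in the unpack stage: the output only reads the first P of the P_round tiles, so TVM's bound inference would otherwise shrink every upstream stage back down to P and undo the rounding. Referencing the last element of M with a zero weight keeps the result unchanged while pinning the full P_round extent. The same effect in miniature (a standalone sketch with made-up shapes):

    import tvm

    n = 10
    A = tvm.placeholder((16,), name='A')                    # padded extent
    B = tvm.compute((16,), lambda i: A[i] * 2, name='B')
    # C reads only B[0:10], so bound inference computes just 10 elements of B...
    C = tvm.compute((n,), lambda i: B[i], name='C')
    # ...unless a zero-weight reference pins B to its full 16-element extent
    C_pin = tvm.compute((n,), lambda i: B[i] + 0 * B[15], name='C_pin')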
From 4640ed7fa857e5112b062610f4de442004dec164 Mon Sep 17 00:00:00 2001
From: Mercy
Date: Wed, 14 Feb 2018 00:07:43 +0800
Subject: [PATCH 4/4] fix comment

---
 topi/python/topi/mali/conv2d.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py
index fdc181cd64ae..5b4cf5bae6ff 100644
--- a/topi/python/topi/mali/conv2d.py
+++ b/topi/python/topi/mali/conv2d.py
@@ -506,7 +506,7 @@ def _schedule_im2col_conv2d(s, op):
     fuse_and_bind(s, output, [n, co, h, w])

 def _decl_winograd(data, kernel, stride, padding, layout, out_dtype):
-    """schedule the winograd fast convolution F(2x2, 3x3) for conv2d"""
+    """declare winograd fast convolution F(2x2, 3x3) for conv2d"""
     N, CI, H, W = [util.get_const_int(x) for x in data.shape]
     CO, CI, KH, KW = [util.get_const_int(x) for x in kernel.shape]
     HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
@@ -603,7 +603,7 @@ def _decl_winograd(data, kernel, stride, padding, layout, out_dtype):
     return output

 def _schedule_winograd(s, op):
-    """schedule the winograd fast convolution F(2x2, 3x3) for conv2d"""
+    """schedule winograd fast convolution F(2x2, 3x3) for conv2d"""

     # get ops and tensors
     output = op.output(0)
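To exercise the new path end to end, something along these lines should work (a hypothetical smoke test against the TVM API of this era; tvm.target.mali(), the local OpenCL context, and the shapes are assumptions, adapt to your setup):

    import numpy as np
    import tvm
    import topi

    # 3x3 kernel, stride 1, pad 1, CO >= 64: dispatched to _decl_winograd
    data = tvm.placeholder((1, 64, 56, 56), name='data')
    kernel = tvm.placeholder((64, 64, 3, 3), name='kernel')

    with tvm.target.mali():
        out = topi.nn.conv2d(data, kernel, 1, 1)
        s = topi.generic.schedule_conv2d_nchw([out])

    func = tvm.build(s, [data, kernel, out], 'opencl')

    ctx = tvm.context('opencl', 0)
    a = tvm.nd.array(np.random.uniform(size=(1, 64, 56, 56)).astype('float32'), ctx)
    w = tvm.nd.array(np.random.uniform(size=(64, 64, 3, 3)).astype('float32'), ctx)
    b = tvm.nd.array(np.zeros((1, 64, 56, 56), dtype='float32'), ctx)
    func(a, w, b)

    ref = topi.testing.conv2d_nchw_python(a.asnumpy(), w.asnumpy(), 1, 1)
    np.testing.assert_allclose(b.asnumpy(), ref, rtol=1e-3)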