From cdcd0e1ffeeee8bb23c4abb1eda954d9cd5f5bab Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Mon, 5 Aug 2019 15:04:48 +0300 Subject: [PATCH] Fix kernel_size != 3 on winograd arm/mali. --- python/tvm/autotvm/tophub.py | 8 +- topi/python/topi/arm_cpu/conv2d.py | 61 +++++++++------ topi/python/topi/cuda/conv2d_winograd.py | 76 +++++++++---------- topi/python/topi/mali/conv2d.py | 37 ++++----- topi/python/topi/nn/winograd_util.py | 29 +++++-- .../tests/python/test_topi_conv2d_winograd.py | 10 +-- 6 files changed, 126 insertions(+), 95 deletions(-) diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index 0130384c2e69..555ff71fe603 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -36,13 +36,13 @@ # the version of each package PACKAGE_VERSION = { - 'arm_cpu': "v0.04", + 'arm_cpu': "v0.05", 'llvm': "v0.03", - 'cuda': "v0.04", + 'cuda': "v0.05", 'rocm': "v0.02", - 'opencl': "v0.02", - 'mali': "v0.05", + 'opencl': "v0.03", + 'mali': "v0.06", 'vta': "v0.06", } diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 62df8f9fbfbe..093446073f42 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -32,7 +32,7 @@ conv2d_winograd_nnpack_without_weight_transform, \ depthwise_conv2d_nchw from ..nn.util import get_const_int, get_pad_tuple -from ..nn.winograd_util import winograd_transform_matrices +from ..nn.winograd_util import winograd_transform_matrices, enum_tile_sizes @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct']) def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): @@ -302,11 +302,14 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd']) def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): """ TOPI compute callback. 
Use winograd template """ - tile_size = 4 - return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, - out_dtype, tile_size) + return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype) + +def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, + tile_size=None): + + cfg.define_knob('tile_size', enum_tile_sizes(data)) + tile_size = tile_size if tile_size else cfg["tile_size"].val -def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size): N, CI, IH, IW = get_const_tuple(data.shape) if isinstance(dilation, int): @@ -326,11 +329,14 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt CO *= VC KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1 HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + dilated_kernel_h = (KH - 1) * dilation_h + 1 + dilated_kernel_w = (KW - 1) * dilation_w + 1 + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) assert layout == 'NCHW' - assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1 - data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") + assert KH == KW and HSTR == 1 and WSTR == 1 + data_pad = pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right]) r = KW m = tile_size @@ -340,8 +346,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt K = CO C = CI - H = (IH + 2 * HPAD - 3) // HSTR + 1 - W = (IW + 2 * WPAD - 3) // WSTR + 1 + H = (IH + pad_top + pad_bottom - KH) // HSTR + 1 + W = (IW + pad_left + pad_right - KW) // WSTR + 1 nH, nW = (H + m-1) // m, (W + m-1) // m P = N * nH * nW @@ -510,12 +516,15 @@ def conv2d_arm_cpu_winograd_nnpack( assert len(kernel.shape) == 4 CO, _, KH, KW = get_const_tuple(kernel.shape) HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - HPAD, 
WPAD, _, _ = get_pad_tuple(padding, kernel) + dilated_kernel_h = (KH - 1) * dilation_h + 1 + dilated_kernel_w = (KW - 1) * dilation_w + 1 + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) assert layout == 'NCHW' - assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1 - H = (IH + 2 * HPAD - 3) // HSTR + 1 - W = (IW + 2 * WPAD - 3) // WSTR + 1 + assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1 + H = (IH + pad_top + pad_bottom - KH) // HSTR + 1 + W = (IW + pad_left + pad_right - KW) // WSTR + 1 cfg.define_knob('winograd_nnpack_algorithm', [convolution_algorithm]) @@ -530,7 +539,7 @@ def conv2d_arm_cpu_winograd_nnpack( output = tvm.contrib.nnpack.convolution_inference_without_weight_transform( data, transformed_kernel, bias=None, - padding=[HPAD, HPAD, WPAD, WPAD], + padding=[pad_top, pad_bottom, pad_left, pad_right], stride=[HSTR, WSTR], algorithm=cfg['winograd_nnpack_algorithm'].val) @@ -590,13 +599,16 @@ def conv2d_winograd_nnpack_ww(cfg, data, transformed_kernel, bias, strides, assert len(transformed_kernel.shape) == 4 CO, _, _, _ = get_const_tuple(transformed_kernel.shape) HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - HPAD, WPAD, _, _ = get_pad_tuple(padding, (3, 3)) KH, KW = 3, 3 + dilated_kernel_h = (KH - 1) * dilation_h + 1 + dilated_kernel_w = (KW - 1) * dilation_w + 1 + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) assert layout == 'NCHW' - assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1 - H = (IH + 2 * HPAD - 3) // HSTR + 1 - W = (IW + 2 * WPAD - 3) // WSTR + 1 + assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1 + H = (IH + pad_top + pad_bottom - KH) // HSTR + 1 + W = (IW + pad_left + pad_right - KW) // WSTR + 1 assert N == 1 with tvm.tag_scope("winograd_nnpack_conv2d_output"): @@ -604,7 +616,7 @@ def 
conv2d_winograd_nnpack_ww(cfg, data, transformed_kernel, bias, strides, data=data, transformed_kernel=transformed_kernel, bias=bias, - padding=[HPAD, HPAD, WPAD, WPAD], + padding=[pad_top, pad_bottom, pad_left, pad_right], stride=[HSTR, WSTR], algorithm=cfg['winograd_nnpack_algorithm'].val) @@ -701,14 +713,16 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): dispatch_ctx.update(target, new_workload, cfg) return F.nn.conv2d(*copy_inputs, **new_attrs) + elif cfg.template_key == "winograd": # pre-compute weight transformation in winograd + # safe default tile_size 2,4 + # 4 (if winograd polynomial not excessive) if "-device=arm_cpu" in target.options: - tile_size = 4 VC = cfg['tile_k'].size[-1] + tile_size = 4 if KH < 4 else 2 else: - from ..mali.conv2d import _pick_tile_size - tile_size = _pick_tile_size(tinfos[0], tinfos[1]) VC = cfg['tile_bna'].val + tile_size = 4 if (H % 4 == 0) and (KH < 4) else 2 weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size) @@ -730,6 +744,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): dispatch_ctx.update(target, new_workload, cfg) return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs) + elif cfg.template_key in ["winograd_nnpack_fp16", "winograd_nnpack_fp32"]: # pre-compute winograd_nnpack transform # for winograd_nnpack_fp16, the the precomputeprune pass must run on device, diff --git a/topi/python/topi/cuda/conv2d_winograd.py b/topi/python/topi/cuda/conv2d_winograd.py index 29f14a0f708e..e3580e721386 100644 --- a/topi/python/topi/cuda/conv2d_winograd.py +++ b/topi/python/topi/cuda/conv2d_winograd.py @@ -20,28 +20,25 @@ import tvm from tvm import autotvm -from .. 
import nn -from ..nn import conv2d, group_conv2d_nchw, conv2d_winograd_without_weight_transform +from ..nn import pad, conv2d, conv2d_alter_layout, group_conv2d_nchw, \ + conv2d_winograd_without_weight_transform from ..util import get_const_int, get_const_tuple, traverse_inline +from ..nn.util import get_pad_tuple from ..generic import schedule_conv2d_winograd_without_weight_transform -from ..nn.winograd_util import winograd_transform_matrices +from ..nn.winograd_util import winograd_transform_matrices, enum_tile_sizes -def _infer_tile_size(data, kernel): - N, CI, H, W = get_const_tuple(data.shape) - - if H % 8 == 0: - return 4 - return 2 - -def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, pre_computed): +def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size=None, + pre_computed=False): """Compute declaration for winograd""" assert layout == 'NCHW' - tile_size = _infer_tile_size(data, kernel) + cfg.define_knob('tile_size', enum_tile_sizes(data)) + tile_size = tile_size if tile_size else cfg["tile_size"].val N, CI, H, W = get_const_tuple(data.shape) + if not pre_computed: # kernel tensor is raw tensor, do strict check if isinstance(dilation, int): dilation_h = dilation_w = dilation @@ -51,18 +48,18 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty kernel = dilate(kernel, (1, 1, dilation_h, dilation_w)) CO, CI, KH, KW = get_const_tuple(kernel.shape) - HPAD, WPAD, _, _ = nn.get_pad_tuple(padding, kernel) + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides assert HSTR == 1 and WSTR == 1 and KH == KW - else: # kernel tensor is pre-transfomred. this op is created by - # alter op layout, do not check + else: # kernel tensor is pre-transformed. 
+ # this op is created by alter op layout, do not check # dilation is not supported HSTR = WSTR = 1 HPAD = WPAD = 1 KH = KW = 3 _, _, CI, CO = get_const_tuple(kernel.shape) - data_pad = nn.pad(data, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad") + data_pad = pad(data, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad") r = KW m = tile_size @@ -106,7 +103,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty # inverse transform r_a = tvm.reduce_axis((0, alpha), 'r_a') - r_b = tvm.reduce_axis((0, alpha), 'r_a') + r_b = tvm.reduce_axis((0, alpha), 'r_b') inverse = tvm.compute((CO, P, m, m), lambda co, p, vh, vw: tvm.sum(bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]), name='inverse') @@ -280,7 +277,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed): ['cuda', 'gpu'], ['winograd']) def conv2d_winograd_ww(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size): return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, - pre_computed=True) + tile_size, pre_computed=True) @autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform, @@ -298,7 +295,7 @@ def _callback(op): ##### REGISTER ALTER OP LAYOUT ##### -@nn.conv2d_alter_layout.register(["cuda", "gpu"]) +@conv2d_alter_layout.register(["cuda", "gpu"]) def _alter_conv2d_layout(attrs, inputs, tinfos, F): """Alter op layout for pre-computing kernel transformation @@ -382,25 +379,28 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F): warnings.warn("Does not support weight pre-transform for dilated convolution.") return None - # pre-compute weight transformation in winograd - tile_size = _infer_tile_size(tinfos[0], tinfos[1]) - - weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], - tile_size=tile_size) - weight = F.transpose(weight, axes=[0, 1, 3, 2]) - copy_inputs[1] = weight - new_attrs['tile_size'] = tile_size - - # Store the same config for the altered 
operator (workload) - new_data = data - new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO), - dtype=kernel.dtype) - new_workload = autotvm.task.args_to_workload( - [new_data, new_weight, strides, padding, dilation, layout, out_dtype, tile_size], - conv2d_winograd_without_weight_transform - ) - dispatch_ctx.update(target, new_workload, cfg) - return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs) + if cfg.template_key == "winograd": + # safe default tile_size + # 4 (if winograd polynomial not excessive) + tile_size = 4 if (H % 8 == 0) and (KH < 4) else 2 + new_attrs['tile_size'] = tile_size + + weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], + tile_size=tile_size) + weight = F.transpose(weight, axes=[0, 1, 3, 2]) + copy_inputs[1] = weight + + # Store the same config for the altered operator (workload) + new_data = data + new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO), + dtype=kernel.dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, layout, out_dtype, tile_size], + conv2d_winograd_without_weight_transform + ) + dispatch_ctx.update(target, new_workload, cfg) + return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs) + if groups != CI: workload = autotvm.task.args_to_workload( [tinfos[0], tinfos[1], strides, padding, dilation, groups, out_dtype], diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py index ddb34628c861..88b7f1dedbb0 100644 --- a/topi/python/topi/mali/conv2d.py +++ b/topi/python/topi/mali/conv2d.py @@ -24,7 +24,7 @@ from ..util import traverse_inline, get_const_int, get_const_tuple from ..nn import conv2d, conv2d_winograd_without_weight_transform, \ get_pad_tuple, pad, conv2d_alter_layout -from ..nn.winograd_util import winograd_transform_matrices +from ..nn.winograd_util import winograd_transform_matrices, enum_tile_sizes # reuse 
some compute declarations from ARM CPU from ..arm_cpu.conv2d import _decl_spatial_pack, _alter_conv2d_layout_arm @@ -188,21 +188,16 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): return s ##### WINOGRAD TEMPLATE ##### -def _pick_tile_size(data, kernel): - N, CI, H, W = get_const_tuple(data.shape) - - if H % 4 == 0: - return 4 - else: - return 2 - @autotvm.register_topi_compute(conv2d, 'mali', ['winograd']) def conv2d_mali_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): - tile_size = _pick_tile_size(data, kernel) - return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, - tile_size) + return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype) + +def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, + tile_size=None): + + cfg.define_knob('tile_size', enum_tile_sizes(data)) + tile_size = tile_size if tile_size else cfg["tile_size"].val -def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size): N, CI, IH, IW = get_const_tuple(data.shape) if isinstance(dilation, int): dilation_h = dilation_w = dilation @@ -222,19 +217,22 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt CO *= VC KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1 HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + dilated_kernel_h = (KH - 1) * dilation_h + 1 + dilated_kernel_w = (KW - 1) * dilation_w + 1 + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) assert layout == 'NCHW' - assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1 - data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") + assert KH == KW and HSTR == 1 and WSTR == 1 + data_pad = pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right]) r = KW m = tile_size alpha 
= m + r - 1 A, B, G = winograd_transform_matrices(m, r, out_dtype) - H = (IH + 2 * HPAD - 3) // HSTR + 1 - W = (IW + 2 * WPAD - 3) // WSTR + 1 + H = (IH + pad_top + pad_bottom - KH) // HSTR + 1 + W = (IW + pad_left + pad_right - KW) // WSTR + 1 nH, nW = (H + m-1) // m, (W + m-1) // m P = N * nH * nW @@ -250,6 +248,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt ##### space definition end ##### if cfg.is_fallback: + # safe default tile_size always 2,4 + # 4 (if winograd polynomial not excessive) + cfg['tile_size'].val = 4 if (IH % 4 == 0) and (KH < 4) else 2 cfg['tile_bnb'].val = 4 cfg['tile_bna'].val = 4 while CO % cfg['tile_bna'].val != 0: diff --git a/topi/python/topi/nn/winograd_util.py b/topi/python/topi/nn/winograd_util.py index 464b63301b40..2c08e5cf4353 100644 --- a/topi/python/topi/nn/winograd_util.py +++ b/topi/python/topi/nn/winograd_util.py @@ -15,24 +15,26 @@ # specific language governing permissions and limitations # under the License. # -""" Utility functions for implementing Winograd convolutions - [*] Fast Algorithms for Convolutional Neural Networks - Andrew Lavin, Scott Gray - https://arxiv.org/abs/1509.09308 - https://github.com/andravin/wincnn -""" +""" Utility functions for implementing Winograd convolutions""" from operator import mul from functools import reduce import numpy as np from tvm.contrib.pickle_memoize import memoize -from ..util import const_matrix +from ..util import const_matrix, get_const_tuple # pylint: disable=invalid-name def _cook_toom_convolution(a, n, r): """Compute Cook-Toom convolution A,B,G matrices""" + # + # [*] Fast Algorithms for Convolutional Neural Networks + # Andrew Lavin, Scott Gray + # https://arxiv.org/abs/1509.09308 + # https://github.com/andravin/wincnn + # + def _F_m(a, n): f = lambda j, i: reduce(mul, ((a[i]-a[k] if k != i else 1) for k in range(0, n-1)), 1) F = np.fromfunction(np.vectorize(f), (1, n-1), dtype=int) @@ -152,3 +154,16 @@ def 
winograd_transform_matrices(tile_size, kernel_size, out_dtype): const_matrix(B_data.astype(out_dtype), "B"), const_matrix(G_data.astype(out_dtype), "G"), ) + +def enum_tile_sizes(data): + """Propose tile sizes for winograd convolution""" + _, _, H, _ = get_const_tuple(data.shape) + + sizes = [2,] + # 2 always present + for s in range(3, 9): + # overlap less than half + if H % s < (H//2): + sizes.append(s) + + return sizes diff --git a/topi/tests/python/test_topi_conv2d_winograd.py b/topi/tests/python/test_topi_conv2d_winograd.py index a42d61d8cb99..d821b1b49c40 100644 --- a/topi/tests/python/test_topi_conv2d_winograd.py +++ b/topi/tests/python/test_topi_conv2d_winograd.py @@ -83,7 +83,7 @@ def check_device(device): func(a, w, c) rtol = 1e-5 - if (kernel > 3): + if (kernel > 5): rtol = 2e-5 tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol) @@ -110,16 +110,16 @@ def test_conv2d_nchw(): with WinogradFallback(): # inception v3 workloads - verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3, devices=['cuda']) - verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3, devices=['cuda']) - verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3, devices=['cuda']) + verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3) + verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3) + verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3) # resnet 18 workloads verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1) verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1) verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) - verify_conv2d_nchw(1, 48, 35, 64, 5, 1, 2, devices=['cuda']) + verify_conv2d_nchw(1, 48, 35, 64, 5, 1, 2) # batch size = 2 verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1)