diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 55d6bba2c43f..73a97d2bb33c 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -35,6 +35,8 @@ from ..nn import conv2d_legalize from ..nn.util import get_const_int, get_pad_tuple from ..nn.winograd_util import winograd_transform_matrices +from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \ + schedule_conv2d_spatial_pack_nchw logger = logging.getLogger('topi') @@ -75,8 +77,11 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dt output : tvm.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, - num_tile=2) + if layout == 'NCHW': + return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, + dilation, out_dtype, num_tile=2) + else: + raise ValueError("Unsupported layout {}".format(layout)) @autotvm.register_topi_schedule( @@ -119,7 +124,8 @@ def _callback(op): if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() - _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0]) + schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, + conv, output, outs[0]) if 'winograd_conv2d_output' in op.tag: output = op.output(0) @@ -132,179 +138,6 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s - -def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, num_tile): - assert layout == "NCHW", "Only support NCHW" - # create workload according to raw arguments - out_dtype = out_dtype or data.dtype - N, CI, IH, IW = get_const_tuple(data.shape) - - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - - if len(kernel.shape) == 4: - pre_packed = False - CO, _, KH, KW = get_const_tuple(kernel.shape) - else: # kernel tensor is pre packed - pre_packed = True - CO, _, KH, KW, VC = get_const_tuple(kernel.shape) - CO = CO * VC - - dilated_kernel_h = (KH - 1) * dilation_h + 1 - dilated_kernel_w = (KW - 1) * dilation_w + 1 - pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple( - padding, (dilated_kernel_h, dilated_kernel_w)) - HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - OH = (IH + pad_top + pad_bottom - dilated_kernel_h) // HSTR + 1 - OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1 - data_pad = pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right]) - - # ==================== define configuration space ==================== - n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW) - ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) - - if num_tile == 2: # for arm cpu - co, vc = cfg.define_split('tile_co', co, num_outputs=2) - oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2) - ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2) - elif num_tile == 3: # for mali gpu - co, _, vc = cfg.define_split('tile_co', co, num_outputs=3) - oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3) - ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3) - else: - raise RuntimeError("Invalid num_tile") - - cfg.define_reorder("reorder_0", - [n, co, oh, ow, ci, kh, kw, vh, vw, vc], - policy='candidate', candidate=[ - [n, co, oh, ow, ci, kh, kw, vh, vw, vc], - [n, co, oh, ow, ci, kh, kw, vc, vh, vw]]) - - cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll') - cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec') - - # fallback support - if cfg.is_fallback: - if num_tile == 2: # arm cpu - ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'conv2d', 'direct') - cfg.fallback_with_reference_log(ref_log) - elif num_tile == 3: # mali gpu - ref_log = autotvm.tophub.load_reference_log('mali', 'rk3399', 'conv2d', 'direct') - cfg.fallback_with_reference_log(ref_log) - # ==================================================================== - - VC = cfg["tile_co"].size[-1] - VH = cfg["tile_oh"].size[-1] - VW = cfg["tile_ow"].size[-1] - - kvshape = (CO // VC, CI, KH, KW, VC) - ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC) - oshape = (N, CO, OH, OW) - - if dilation_h != 1 or dilation_w != 1: - # undilate input data - dvshape = (N, OH // VH, OW // VW, CI, KH, KW, VH, VW) - data_vec = tvm.compute(dvshape, lambda n, h, w, ci, kh, kw, vh, vw: - data_pad[n][ci][(h*VH+vh)*HSTR+kh*dilation_h] - [(w*VW+vw)*WSTR+kw*dilation_w], - name='data_vec_undilated') - else: - dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1) - data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw: - data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], - name='data_vec') - - if pre_packed: - kernel_vec = kernel - else: - kernel_vec = tvm.compute(kvshape, lambda co, ci, kh, kw, vc: - kernel[co*VC+vc][ci][kh][kw], - name='kernel_vec') - - ci = tvm.reduce_axis((0, CI), name='ci') - kh = tvm.reduce_axis((0, KH), name='kh') - kw = tvm.reduce_axis((0, KW), name='kw') - - if dilation_h != 1 or dilation_w != 1: - conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \ - tvm.sum(data_vec[n, h, w, ci, kh, kw, vh, vw].astype(out_dtype) * - kernel_vec[co, ci, kh, kw, vc].astype(out_dtype), - axis=[ci, kh, kw]), name='conv') - else: - conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \ - tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) * - kernel_vec[co, ci, kh, kw, vc].astype(out_dtype), - axis=[ci, kh, kw]), name='conv') - - output = tvm.compute(oshape, lambda n, co, h, w: - conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC], - name='output_unpack', tag='spatial_conv2d_output') - return output - -def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, - conv, output, last): - """schedule implementation""" - n, co, oh, ow, vh, vw, vc = s[conv].op.axis - ci, kh, kw = s[conv].op.reduce_axis - - # schedule conv - cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, ci, kh, kw, vh, vw, vc]) - cfg["ann_reduce"].apply(s, conv, [kh, kw], - axis_lens=[get_const_int(kh.dom.extent), - get_const_int(kw.dom.extent)], - max_unroll=16, - cfg=cfg) - cfg["ann_spatial"].apply(s, conv, [vh, vw, vc], - axis_lens=[cfg['tile_oh'].size[-1], - cfg['tile_ow'].size[-1], - cfg['tile_co'].size[-1]], - max_unroll=16, - cfg=cfg) - - # schedule fusion - n, co, h, w = s[last].op.axis - co, vc = cfg['tile_co'].apply(s, last, co) - oh, vh = cfg['tile_oh'].apply(s, last, h) - ow, vw = cfg['tile_ow'].apply(s, last, w) - s[last].reorder(n, co, oh, ow, vh, vw, vc) - if last != output: - s[output].compute_inline() - cfg["ann_spatial"].apply(s, last, [vh, vw, vc], - axis_lens=[cfg['tile_oh'].size[-1], - cfg['tile_ow'].size[-1], - cfg['tile_co'].size[-1]], - max_unroll=16, - cfg=cfg) - s[conv].compute_at(s[last], ow) - - # mark parallel - p = s[last].fuse(n, co) - s[last].parallel(p) - - if data_vec.op.name == 'data_vec_undilated': - n, h, _, _, _, _, _, _ = s[data_vec].op.axis - else: - n, h, _, _, _, _ = s[data_vec].op.axis - p = s[data_vec].fuse(n, h) - s[data_vec].parallel(p) - - if kernel_vec.op.name == 'kernel_vec': - co, _, _, _, _ = s[kernel_vec].op.axis - if autotvm.GLOBAL_SCOPE.in_tuning: - # kernel packing will be pre-computed during compilation, so we skip - # this part to make tuning records correct - s[kernel_vec].pragma(co, 'debug_skip_region') - else: - s[kernel_vec].parallel(co) - elif kernel_vec.op.name == 'kernel_vec_conv2d_transpose': # for conv2d transpose - co, _, _, _, _ = s[kernel_vec].op.axis - s[kernel_vec].parallel(co) - - return s - - @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd']) def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): """ TOPI compute callback. Use winograd template """ diff --git a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py new file mode 100644 index 000000000000..2066b9a18747 --- /dev/null +++ b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py @@ -0,0 +1,193 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,no-else-return +"""Conv2D spatial pack implementation for ARM CPU""" +from __future__ import absolute_import as _abs +import tvm +from tvm import autotvm +from .. import nn +from ..util import get_const_tuple +from ..nn.util import get_const_int, get_pad_tuple + +def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation, + out_dtype, num_tile): + """compute define for Conv2d Spatial Pack with NCHW layout""" + out_dtype = out_dtype or data.dtype + N, CI, IH, IW = get_const_tuple(data.shape) + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + if len(kernel.shape) == 4: + pre_packed = False + CO, _, KH, KW = get_const_tuple(kernel.shape) + else: # kernel tensor is pre packed + pre_packed = True + CO, _, KH, KW, VC = get_const_tuple(kernel.shape) + CO = CO * VC + + dilated_kernel_h = (KH - 1) * dilation_h + 1 + dilated_kernel_w = (KW - 1) * dilation_w + 1 + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) + OH = (IH + pad_top + pad_bottom - dilated_kernel_h) // HSTR + 1 + OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1 + data_pad = nn.pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right]) + + # ==================== define configuration space ==================== + n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW) + ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) + + if num_tile == 2: # for arm cpu + co, vc = cfg.define_split('tile_co', co, num_outputs=2) + oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2) + ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2) + elif num_tile == 3: # for mali gpu + co, _, vc = cfg.define_split('tile_co', co, num_outputs=3) + oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3) + ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3) + else: + raise RuntimeError("Invalid num_tile") + + cfg.define_reorder("reorder_0", + [n, co, oh, ow, ci, kh, kw, vh, vw, vc], + policy='candidate', candidate=[ + [n, co, oh, ow, ci, kh, kw, vh, vw, vc], + [n, co, oh, ow, ci, kh, kw, vc, vh, vw]]) + + cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll') + cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec') + + # fallback support + if cfg.is_fallback: + if num_tile == 2: # arm cpu + ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'conv2d', 'direct') + cfg.fallback_with_reference_log(ref_log) + elif num_tile == 3: # mali gpu + ref_log = autotvm.tophub.load_reference_log('mali', 'rk3399', 'conv2d', 'direct') + cfg.fallback_with_reference_log(ref_log) + # ==================================================================== + + VC = cfg["tile_co"].size[-1] + VH = cfg["tile_oh"].size[-1] + VW = cfg["tile_ow"].size[-1] + + kvshape = (CO // VC, CI, KH, KW, VC) + ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC) + oshape = (N, CO, OH, OW) + + if dilation_h != 1 or dilation_w != 1: + # undilate input data + dvshape = (N, OH // VH, OW // VW, CI, KH, KW, VH, VW) + data_vec = tvm.compute(dvshape, lambda n, h, w, ci, kh, kw, vh, vw: + data_pad[n][ci][(h*VH+vh)*HSTR+kh*dilation_h] + [(w*VW+vw)*WSTR+kw*dilation_w], + name='data_vec_undilated') + else: + dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1) + data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw: + data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], + name='data_vec') + + if pre_packed: + kernel_vec = kernel + else: + kernel_vec = tvm.compute(kvshape, lambda co, ci, kh, kw, vc: + kernel[co*VC+vc][ci][kh][kw], + name='kernel_vec') + + ci = tvm.reduce_axis((0, CI), name='ci') + kh = tvm.reduce_axis((0, KH), name='kh') + kw = tvm.reduce_axis((0, KW), name='kw') + + if dilation_h != 1 or dilation_w != 1: + conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \ + tvm.sum(data_vec[n, h, w, ci, kh, kw, vh, vw].astype(out_dtype) * + kernel_vec[co, ci, kh, kw, vc].astype(out_dtype), + axis=[ci, kh, kw]), name='conv') + else: + conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \ + tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) * + kernel_vec[co, ci, kh, kw, vc].astype(out_dtype), + axis=[ci, kh, kw]), name='conv') + + output = tvm.compute(oshape, lambda n, co, h, w: + conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC], + name='output_unpack', tag='spatial_conv2d_output') + return output + +def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, + conv, output, last): + """schedule implementation""" + n, co, oh, ow, vh, vw, vc = s[conv].op.axis + ci, kh, kw = s[conv].op.reduce_axis + + # schedule conv + cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, ci, kh, kw, vh, vw, vc]) + cfg["ann_reduce"].apply(s, conv, [kh, kw], + axis_lens=[get_const_int(kh.dom.extent), + get_const_int(kw.dom.extent)], + max_unroll=16, + cfg=cfg) + cfg["ann_spatial"].apply(s, conv, [vh, vw, vc], + axis_lens=[cfg['tile_oh'].size[-1], + cfg['tile_ow'].size[-1], + cfg['tile_co'].size[-1]], + max_unroll=16, + cfg=cfg) + + # schedule fusion + n, co, h, w = s[last].op.axis + co, vc = cfg['tile_co'].apply(s, last, co) + oh, vh = cfg['tile_oh'].apply(s, last, h) + ow, vw = cfg['tile_ow'].apply(s, last, w) + s[last].reorder(n, co, oh, ow, vh, vw, vc) + if last != output: + s[output].compute_inline() + cfg["ann_spatial"].apply(s, last, [vh, vw, vc], + axis_lens=[cfg['tile_oh'].size[-1], + cfg['tile_ow'].size[-1], + cfg['tile_co'].size[-1]], + max_unroll=16, + cfg=cfg) + s[conv].compute_at(s[last], ow) + + # mark parallel + s[last].parallel(co) + + if data_vec.op.name == 'data_vec_undilated': + _, h, _, _, _, _, _, _ = s[data_vec].op.axis + else: + _, h, _, _, _, _ = s[data_vec].op.axis + s[data_vec].parallel(h) + + if kernel_vec.op.name == 'kernel_vec': + co, _, _, _, _ = s[kernel_vec].op.axis + if autotvm.GLOBAL_SCOPE.in_tuning: + # kernel packing will be pre-computed during compilation, so we skip + # this part to make tuning records correct + s[kernel_vec].pragma(co, 'debug_skip_region') + else: + s[kernel_vec].parallel(co) + elif kernel_vec.op.name == 'kernel_vec_conv2d_transpose': # for conv2d transpose + co, _, _, _, _ = s[kernel_vec].op.axis + s[kernel_vec].parallel(co) + + return s diff --git a/topi/python/topi/arm_cpu/conv2d_transpose.py b/topi/python/topi/arm_cpu/conv2d_transpose.py index 99cb046776e5..bd56ef385631 100644 --- a/topi/python/topi/arm_cpu/conv2d_transpose.py +++ b/topi/python/topi/arm_cpu/conv2d_transpose.py @@ -24,7 +24,7 @@ from ..generic import schedule_conv2d_transpose_nchw from ..nn import conv2d_transpose_nchw, dilate, pad, get_pad_tuple from ..util import get_const_tuple, traverse_inline -from .conv2d import _schedule_spatial_pack +from .conv2d_spatial_pack import schedule_conv2d_spatial_pack_nchw @autotvm.task.register_topi_compute(conv2d_transpose_nchw, "arm_cpu", "direct") def conv2d_transpose_nchw_arm(cfg, Input, Filter, strides, padding, out_dtype): @@ -154,7 +154,8 @@ def _callback(op): if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() - _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0]) + schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, + conv, output, outs[0]) traverse_inline(s, outs[0].op, _callback) return s diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py index ddb34628c861..23821273f56d 100644 --- a/topi/python/topi/mali/conv2d.py +++ b/topi/python/topi/mali/conv2d.py @@ -27,7 +27,8 @@ from ..nn.winograd_util import winograd_transform_matrices # reuse some compute declarations from ARM CPU -from ..arm_cpu.conv2d import _decl_spatial_pack, _alter_conv2d_layout_arm +from ..arm_cpu.conv2d import _alter_conv2d_layout_arm +from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nchw @autotvm.register_topi_compute(conv2d, 'mali', ['direct']) @@ -67,8 +68,11 @@ def conv2d_mali(cfg, data, kernel, strides, padding, dilation, layout, out_dtype output : tvm.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, - num_tile=3) + if layout == 'NCHW': + return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, + dilation, out_dtype, num_tile=3) + else: + raise ValueError("Unsupported layout {}".format(layout)) @autotvm.register_topi_schedule(schedule_conv2d_nchw, 'mali', ['direct', 'winograd']) def schedule_conv2d_nchw_mali(cfg, outs):