From b1b6fc30994d82dbed00cb1afe30e1b49da5dbca Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 30 Oct 2018 15:29:11 -0700 Subject: [PATCH 1/9] add x86 depthwise_conv2d NCHWc --- nnvm/python/nnvm/top/nn.py | 11 +- nnvm/src/top/nn/convolution.cc | 4 +- topi/python/topi/generic/nn.py | 18 +++ topi/python/topi/nn/depthwise_conv2d.py | 59 +++++++ topi/python/topi/x86/__init__.py | 1 + topi/python/topi/x86/conv2d.py | 84 ++++++---- topi/python/topi/x86/depthwise_conv2d.py | 150 ++++++++++++++++++ topi/python/topi/x86/util.py | 13 ++ .../python/test_topi_depthwise_conv2d.py | 105 ++++++++++++ 9 files changed, 412 insertions(+), 33 deletions(-) create mode 100644 topi/python/topi/x86/depthwise_conv2d.py create mode 100644 topi/python/topi/x86/util.py diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index a4b36ea853d5..23101531c885 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -4,7 +4,7 @@ import tvm import topi -from topi.util import get_const_int +from topi.util import get_const_int, get_const_tuple from .tensor import _fschedule_broadcast, _fschedule_injective from . import registry as reg from .registry import OpPattern @@ -167,16 +167,22 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): padding = attrs.get_int_tuple("padding") strides = attrs.get_int_tuple("strides") dilation = attrs.get_int_tuple("dilation") + channels = attrs.get_int("channels") groups = attrs.get_int("groups") layout = attrs.get_string("layout") out_layout = attrs.get_string("out_layout") out_dtype = attrs.get_string("out_dtype") out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype + _, in_channel_chunk, _, _, in_channel_block = get_const_tuple(inputs[0].shape) + in_channel = in_channel_chunk * in_channel_block assert dilation == (1, 1), "not support dilate now" if groups == 1: # pylint: disable=assignment-from-no-return out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding, layout, out_layout, out_dtype) + elif groups == in_channel and groups == channels: + out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, + layout, out_layout, out_dtype) # pylint: enable=assignment-from-no-return else: raise ValueError("not support arbitrary group number > 1 for now") @@ -190,9 +196,12 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): def schedule_contrib_conv2d_NCHWc(attrs, outs, target): """Schedule definition of conv2d NCHWc""" groups = attrs.get_int("groups") + channels = attrs.get_int("channels") with tvm.target.create(target): if groups == 1: return topi.generic.schedule_conv2d_NCHWc(outs) + elif groups == channels: + return topi.generic.schedule_depthwise_conv2d_NCHWc(outs) else: raise ValueError("not support group number > 1 for now") diff --git a/nnvm/src/top/nn/convolution.cc b/nnvm/src/top/nn/convolution.cc index 22bda048a0a2..813947492117 100644 --- a/nnvm/src/top/nn/convolution.cc +++ b/nnvm/src/top/nn/convolution.cc @@ -82,7 +82,9 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, wshape[kernel_layout.indexof('O')] *= param.groups; - NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape); + if (in_shape->at(Conv2DParam::kWeight).ndim() == 0) { + NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape); + } if (param.use_bias) { static const Layout default_bias_layout("C"); TShape bias_shape({param.channels}); diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 765b48d286bc..f452a8fcb1d5 100644 --- a/topi/python/topi/generic/nn.py +++ 
b/topi/python/topi/generic/nn.py @@ -191,6 +191,24 @@ def schedule_depthwise_conv2d_nhwc(outs): """ return _default_schedule(outs, False) + +@tvm.target.generic_func +def schedule_depthwise_conv2d_NCHWc(outs): + """Schedule for depthwise_conv2d_NCHWc + Parameters + ---------- + outs: Array of Tensor + The computation graph description of depthwise_conv2d_nhwc + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + @tvm.target.generic_func def schedule_bitserial_conv2d_nchw(outs): """Schedule for bitserial_conv2d_nchw diff --git a/topi/python/topi/nn/depthwise_conv2d.py b/topi/python/topi/nn/depthwise_conv2d.py index c7906d3a4373..8782c3e30a5f 100644 --- a/topi/python/topi/nn/depthwise_conv2d.py +++ b/topi/python/topi/nn/depthwise_conv2d.py @@ -1,6 +1,7 @@ # pylint: disable=invalid-name, unused-variable, too-many-locals """Depthwise convolution operators""" from __future__ import absolute_import as _abs +from collections import namedtuple import tvm from .dilate import dilate @@ -8,6 +9,27 @@ from .util import get_pad_tuple from ..util import simplify +# workload description of depthwise-conv2d +Workload = namedtuple('Workload', + ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) + +def _get_workload(data, kernel, stride, padding, out_dtype): + """ Get the workload structure. """ + _, in_channel, height, width = [x.value for x in data.shape] + channel, channel_multiplier, kh, kw = [x.value for x in kernel.shape] + out_channel = channel * channel_multiplier + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + if isinstance(stride, (tuple, list)): + HSTR, WSTR = stride + else: + HSTR, WSTR = stride, stride + assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \ + "Do not support inputs with different data types now. ' \ + '{} vs. {}".format(data.dtype, kernel.dtype) + return Workload(data.dtype, out_dtype, height, width, in_channel, + out_channel, kh, kw, HPAD, WPAD, HSTR, WSTR) + @tvm.target.generic_func def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype=None): @@ -232,3 +254,40 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid tag='depthwise_conv2d_backward_weight_nhwc') return Weight_grad + + +@tvm.target.generic_func +def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, layout, out_layout, out_dtype=None): + """Depthwise convolution NCHW[x]c forward operator. + + Parameters + ---------- + Input : tvm.Tensor + 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] + + Filter : tvm.Tensor + 4-D with shape [out_channel_chunk, filter_height, filter_width, out_channel_block] + In NCHWc depthwise convolution, + we group kernel's in_channel and channel_multiplier together then do the tiling. 
+ + stride : tuple of two ints + The spatial stride along height and width + + padding : int or str + Padding size, or ['VALID', 'SAME'] + + layout : str + Input data layout + + out_layout : str + Output data layout + + out_dtype: str, optional + Output data type + + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, out_channel, out_height, out_width] + """ + raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc") diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index c146419fcec9..9e0e94e6cd2d 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -9,3 +9,4 @@ from .injective import * from .pooling import schedule_pool, schedule_global_pool from .bitserial_conv2d import schedule_bitserial_conv2d +from .depthwise_conv2d import schedule_depthwise_conv2d_NCHWc diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 3dc6d5e4bab8..1e613fde2fd2 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -7,21 +7,15 @@ from .. import generic, tag from .. import nn from ..util import get_const_tuple -from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, _get_workload +from ..nn.conv2d import conv2d, conv2d_NCHWc, \ + conv2d_alter_layout, _get_workload as _get_conv2d_workload +from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload +from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc from ..nn.pad import pad +from .util import get_fp32_len from . import conv2d_avx_1x1, conv2d_avx_common -def _get_fp32_len(): - fp32_vec_len = 8 - target = tvm.target.current_target() - if target is not None: - for opt in target.options: - if opt == '-mcpu=skylake-avx512': - fp32_vec_len = 16 - return fp32_vec_len - - def _get_default_config(cfg, workload): """ Get default schedule config for the workload @@ -30,7 +24,7 @@ def _get_default_config(cfg, workload): workload : topi.nn.conv2d.Workload Convolution workload """ - fp32_vec_len = _get_fp32_len() + fp32_vec_len = get_fp32_len() is_kernel_1x1 = workload.hkernel == 1 and workload.wkernel == 1 if is_kernel_1x1: conv2d_avx_1x1._fallback_schedule(cfg, workload, fp32_vec_len) @@ -72,7 +66,7 @@ def _declaration_conv(cfg, data, kernel, strides, padding, layout, out_dtype): if layout == 'NCHW': _create_tuning_space(cfg, data, kernel, strides, padding, layout) if cfg.is_fallback: - wkl = _get_workload(data, kernel, strides, padding, out_dtype) + wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype) _get_default_config(cfg, wkl) return _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtype) elif layout == 'HWCN': @@ -283,18 +277,24 @@ def _alter_conv2d_layout(attrs, inputs, tinfo): copy_inputs = [s for s in inputs] new_attrs = {k : attrs[k] for k in attrs.keys()} data, kernel = tinfo[0], tinfo[1] - # only optimize for NCHW, groups=1 conv - if attrs['layout'] != 'NCHW' or attrs.get_int("groups") != 1: - return None batch_size, in_channel, height, width = get_const_tuple(data.shape) - out_channel, _, kh, kw = get_const_tuple(kernel.shape) + groups = attrs.get_int("groups") + out_channel = attrs.get_int("channels") padding = attrs.get_int_tuple("padding") strides = attrs.get_int_tuple("strides") layout = attrs['layout'] + kh, kw = attrs.get_int_tuple("kernel_size") dtype = data.dtype out_dtype = dtype if attrs["out_dtype"] == "same" else attrs["out_dtype"] + is_depthwise = groups == in_channel and groups == out_channel + + # only optimize for NCHW 
+ if layout != 'NCHW': + return None + if groups != 1 and not is_depthwise: + return None workload = autotvm.task.args_to_workload( [data, kernel, strides, padding, layout, out_dtype], conv2d) @@ -302,25 +302,47 @@ def _alter_conv2d_layout(attrs, inputs, tinfo): target = tvm.target.current_target() cfg = dispatch_ctx.query(target, workload) if cfg.is_fallback: - wkl = _get_workload(data, kernel, strides, padding, out_dtype) - _get_default_config(cfg, wkl) + if is_depthwise: + wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype) + from depthwise_conv2d import fallback_schedule + fallback_schedule(cfg, wkl, get_fp32_len()) + else: + wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype) + _get_default_config(cfg, wkl) ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] new_attrs['layout'] = 'NCHW%dc' % ic_bn new_attrs['out_layout'] = 'NCHW%dc' % oc_bn - # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) - new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) - # Store the same config for the altered operator (workload) new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), dtype=data.dtype) - new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn), - dtype=kernel.dtype) - new_workload = autotvm.task.args_to_workload( - [new_data, new_kernel, strides, padding, new_attrs['layout'], - new_attrs['out_layout'], out_dtype], conv2d_NCHWc) - dispatch_ctx.update(target, new_workload, cfg) + if is_depthwise: + # channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block + kernel_sym = copy_inputs[1] + kernel_sym = sym.transpose(kernel_sym, axes=(2, 3, 0, 1)) + kernel_sym = sym.reshape(kernel_sym, shape=(kh, kw, out_channel)) + kernel_sym = sym.reshape(kernel_sym, shape=(kh, kw, out_channel//oc_bn, oc_bn)) + kernel_sym = sym.transpose(kernel_sym, axes=(2, 0, 1, 3)) + copy_inputs[1] = kernel_sym + + # Store altered operator's config + new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn), dtype=kernel.dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, new_attrs['layout'], + new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc) + else: + out_channel, _, kh, kw = get_const_tuple(kernel.shape) + # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) + new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) + + # Store altered operator's config + new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn), + dtype=kernel.dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, new_attrs['layout'], + new_attrs['out_layout'], out_dtype], conv2d_NCHWc) + dispatch_ctx.update(target, new_workload, cfg) return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) @@ -341,9 +363,9 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, num_filter = oc_chunk * oc_bn # get workload and related schedule config - wkl = _get_workload(tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), - tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width), - dtype=kernel.dtype), + wkl = _get_conv2d_workload(tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), + tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width), + dtype=kernel.dtype), strides, padding, out_dtype) if cfg.is_fallback: _get_default_config(cfg, wkl) diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py new file mode 
100644 index 000000000000..e6bf342c0420 --- /dev/null +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -0,0 +1,150 @@ +import tvm +from tvm import autotvm +from tvm.autotvm.task.space import SplitEntity +from .. import generic, tag +from ..nn.pad import pad +from ..util import get_const_tuple +from ..nn.util import get_pad_tuple +from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload + +from .util import get_fp32_len + +def fallback_schedule(cfg, wkl, simd_width): + HPAD, WPAD = wkl.hpad, wkl.wpad + HSTR, WSTR = wkl.hstride, wkl.wstride + out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 + out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 + + oc_bn = 1 + for bn in range(simd_width, 0, -1): + if wkl.out_filter % bn == 0: + oc_bn = bn + break + + ic_bn = 1 + for bn in range(oc_bn, 0, -1): + if wkl.in_filter % bn == 0: + ic_bn = bn + break + + reg_n = 1 + for n in range(31, 0, -1): + if out_width % n == 0: + reg_n = n + break + + cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) + cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) + cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n]) + + +@autotvm.register_topi_compute(depthwise_conv2d_NCHWc, 'cpu', 'direct') +def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, + layout, out_layout, out_dtype=None): + out_dtype = data.dtype if out_dtype is None else out_dtype + batch, in_channel_chunk, in_height, in_width, in_channel_block = get_const_tuple(data.shape) + out_channel_chunk, filter_height, filter_width, out_channel_block = get_const_tuple(kernel.shape) + + strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) + HSTR, WSTR = strides + pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (filter_height, filter_width)) + + in_channel = in_channel_chunk * in_channel_block + out_channel = out_channel_chunk * out_channel_block + channel_multiplier = out_channel // in_channel + + out_height = (in_height - filter_height + pad_top + pad_down) // HSTR + 1 + out_width = (in_width - filter_width + pad_left + pad_right) // WSTR + 1 + + # get workload and related schedule config + wkl = _get_workload(tvm.placeholder((batch, in_channel, in_height, in_width), dtype=data.dtype), + tvm.placeholder((out_channel, in_channel, filter_height, filter_width), + dtype=kernel.dtype), + strides, padding, out_dtype) + if cfg.is_fallback: + fallback_schedule(cfg, wkl, get_fp32_len()) + + # padding stage + DOPAD = (pad_top != 0 or pad_left != 0 or pad_down != 0 or pad_right != 0) + if DOPAD: + pad_before = [0, 0, pad_top, pad_left, 0] + pad_after = [0, 0, pad_down, pad_right, 0] + data_pad = pad(data, pad_before, pad_after, name="PaddedInput") + else: + data_pad = data + + # depthconv stage + di = tvm.reduce_axis((0, filter_height), name='di') + dj = tvm.reduce_axis((0, filter_width), name='dj') + Output = tvm.compute( + (batch, out_channel_chunk, out_height, out_width, out_channel_block), + lambda b, oco, i, j, oci: tvm.sum( + (data_pad[b, (oco * out_channel_block + oci) // channel_multiplier // in_channel_block, + i*HSTR+di, j*WSTR+dj, + ((oco * out_channel_block + oci) // channel_multiplier) % in_channel_block] + .astype(out_dtype) * + kernel[oco, di, dj, oci].astype(out_dtype)), + axis=[di, dj]), + name='DepthwiseConv2d', tag="depthwise_conv2d_NCHWc") + return Output + + +@autotvm.register_topi_schedule(generic.schedule_depthwise_conv2d_NCHWc, 'cpu', ['direct']) +def schedule_depthwise_conv2d_NCHWc(cfg, outs): + s = tvm.create_schedule([x.op for x 
in outs]) + scheduled_ops = [] + def traverse(op): + """Traverse operators from computation graph""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + if 'depthwise_conv2d_NCHWc' in op.tag: + conv_out = op.output(0) + data = conv_out.op.input_tensors[0] + input = data + kernel = conv_out.op.input_tensors[1] + _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, input, kernel, conv_out, outs[0]) + scheduled_ops.append(op) + traverse(outs[0].op) + return s + +def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data, kernel, conv_out, output): + tile_ow = cfg["tile_ow"].size[-1] + # schedule data + A = data + if isinstance(s[A].op, tvm.tensor.ComputeOp): + batch, ic_chunk, ih, iw, ic_block = s[A].op.axis + p = s[A].fuse(ic_chunk, ih) + s[A].parallel(p) + + C, O = conv_out, output + CC = s.cache_write(C, 'global') + + _, ic_chunk, oh, ow, ic_block = s[C].op.axis + ow_chunk, ow_block = s[C].split(ow, factor=tile_ow) + s[C].reorder(ic_chunk, oh, ow_chunk, ow_block, ic_block) + s[C].vectorize(ic_block) + parallel_axis = s[C].fuse(ic_chunk, oh) + s[C].parallel(parallel_axis) + s[C].unroll(ow_block) + s[CC].compute_at(s[C], ow_chunk) + + _, ic_chunk, oh, ow, ic_block = s[CC].op.axis + kh, kw = s[CC].op.reduce_axis + ow_chunk, ow_block = s[CC].split(ow, factor=tile_ow) + s[CC].reorder(ic_chunk, oh, kh, kw, ow_block, ic_block) + s[CC].vectorize(ic_block) + s[CC].unroll(ow_block) + if C != O: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + ow_chunk, ow_block = s[O].split(ow, factor=tile_ow) + s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[O].fuse(oc_chunk, oh) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + return s diff --git a/topi/python/topi/x86/util.py b/topi/python/topi/x86/util.py new file mode 100644 index 000000000000..d7117944cf40 --- /dev/null +++ b/topi/python/topi/x86/util.py @@ -0,0 +1,13 @@ +# pylint: disable=invalid-name +"""Common x86 related utilities""" +from __future__ import absolute_import as _abs +import tvm + +def get_fp32_len(): + fp32_vec_len = 8 + target = tvm.target.current_target() + if target is not None: + for opt in target.options: + if opt == '-mcpu=skylake-avx512': + fp32_vec_len = 16 + return fp32_vec_len \ No newline at end of file diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py index 51f2c418c121..06cf5f729a90 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d.py +++ b/topi/tests/python/test_topi_depthwise_conv2d.py @@ -206,6 +206,111 @@ def get_ref_data(): check_device(device) +def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1): + def _transform_data(data, bn): + # NCHW -> NCHW[x]c + batch_size, channel, height, width = data.shape + data = np.transpose(data, (0, 2, 3, 1)) + data = np.reshape(data, (batch_size, height, width, channel//bn, bn)) + data = np.transpose(data, (0, 3, 1, 2, 4)) + return data + + def _transform_kernel(kernel, bn): + # channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block + channel, channel_multiplier, kh, kw = kernel.shape + out_channel = channel * channel_multiplier + kernel = np.transpose(kernel, axes=(2, 3, 0, 1)) + kernel = np.reshape(kernel, shape=(kh, 
kw, out_channel)) + kernel = np.reshape(kernel, shape=(kh, kw, out_channel//bn, bn)) + kernel = np.transpose(kernel, axes=(2, 0, 1, 3)) + return kernel + + in_width = in_height + filter_channel = in_channel + filter_width = filter_height + stride_h = stride_w = stride + + assert dilation == 1, "depthwise_conv2d_NCHWc currently does not support dilation." + pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width)) + padding_args = (pad_h, pad_w) + + out_channel = filter_channel * channel_multiplier + # for testing functionality, + # we choose arbitrary block size that can divide the channel, + # regardless of the performance. + oc_block = 1 + for bn in range(16, 0, -1): + if num_filter % bn == 0: + oc_block = bn + break + + ic_block = 1 + for bn in range(oc_block, 0, -1): + if in_channel % bn == 0: + ic_block = bn + break + + # placeholder + Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input') + Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter') + + dtype = 'float32' + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + # declare + DepthwiseConv2d = topi.nn.depthwise_conv2d_NCHWc(Input, Filter, + (stride_h, stride_w), + padding_args, dtype) + # TODO: add scale_shift for NCHWc and add test here + Relu = topi.nn.relu(DepthwiseConv2d) + # schedule + s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d) + s2 = topi.generic.schedule_depthwise_conv2d_nchw(Relu) + # build the kernels + f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device) + f2 = tvm.build(s2, [Input, Filter, Relu], device) + + # Prepare pod type for test data closure + input_shape = get_const_tuple(Input.shape) + filter_shape = get_const_tuple(Filter.shape) + + # Use memoize, pickle the test data for next time use. 
+ @memoize("topi.tests.test_topi_depthwise_conv2d.nchw") + def get_ref_data(): + input_np = np.random.uniform(size=input_shape).astype(dtype) + filter_np = np.random.uniform(size=filter_shape).astype(dtype) + # correctness with scipy + depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw( + input_np, filter_np, stride, padding) + relu_scipy = np.maximum(depthwise_conv2d_scipy, 0) + return (input_np, filter_np, depthwise_conv2d_scipy, relu_scipy) + # Get the test data + (input_np, filter_np, depthwise_conv2d_scipy, relu_scipy) = get_ref_data() + + input_tvm = tvm.nd.array(input_np, ctx) + filter_tvm = tvm.nd.array(filter_np, ctx) + depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx) + relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx) + # launch kernel 1 (depthwise_conv2d) + timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1) + tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean + # launch kernel 2 (depthwise_conv2d + relu) + timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1) + tcost_2 = timer_2(input_tvm, filter_tvm, relu_tvm).mean + tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5) + tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) + + for device in get_all_backend(): + with autotvm.tophub.context(device): # load tophub pre-tuned parameters + check_device(device) + + def test_depthwise_conv2d(): # mobilenet workloads depthwise_conv2d_with_workload_nchw(1, 32, 112, 1, 3, 1, "SAME") From b48a35e7c1bf895270982997e097c467cd72c848 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 31 Oct 2018 14:13:50 -0700 Subject: [PATCH 2/9] add test cases --- topi/python/topi/x86/conv2d.py | 51 +++++++------ .../python/test_topi_depthwise_conv2d.py | 74 +++++++++++-------- 2 files changed, 71 insertions(+), 54 deletions(-) diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 1e613fde2fd2..912ae190ea00 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -10,26 +10,34 @@ from ..nn.conv2d import conv2d, conv2d_NCHWc, \ conv2d_alter_layout, _get_workload as _get_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload -from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc +from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw from ..nn.pad import pad from .util import get_fp32_len from . 
import conv2d_avx_1x1, conv2d_avx_common -def _get_default_config(cfg, workload): +def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False): """ Get default schedule config for the workload Parameters ---------- workload : topi.nn.conv2d.Workload Convolution workload + is_depthwise : bool + Whether it is depthwise NCHW workload """ fp32_vec_len = get_fp32_len() - is_kernel_1x1 = workload.hkernel == 1 and workload.wkernel == 1 - if is_kernel_1x1: - conv2d_avx_1x1._fallback_schedule(cfg, workload, fp32_vec_len) + if is_depthwise: + wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype) + from depthwise_conv2d import fallback_schedule + fallback_schedule(cfg, wkl, fp32_vec_len) else: - conv2d_avx_common._fallback_schedule(cfg, workload, fp32_vec_len) + wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype) + is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1 + if is_kernel_1x1: + conv2d_avx_1x1._fallback_schedule(cfg, wkl, fp32_vec_len) + else: + conv2d_avx_common._fallback_schedule(cfg, wkl, fp32_vec_len) def _create_tuning_space(cfg, data, kernel, strides, padding, layout): @@ -66,8 +74,7 @@ def _declaration_conv(cfg, data, kernel, strides, padding, layout, out_dtype): if layout == 'NCHW': _create_tuning_space(cfg, data, kernel, strides, padding, layout) if cfg.is_fallback: - wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype) - _get_default_config(cfg, wkl) + _get_default_config(cfg, data, kernel, strides, padding, out_dtype) return _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtype) elif layout == 'HWCN': return nn.conv2d_hwcn(data, kernel, strides, padding, out_dtype) @@ -296,19 +303,17 @@ def _alter_conv2d_layout(attrs, inputs, tinfo): if groups != 1 and not is_depthwise: return None - workload = autotvm.task.args_to_workload( - [data, kernel, strides, padding, layout, out_dtype], conv2d) dispatch_ctx = autotvm.task.DispatchContext.current target = tvm.target.current_target() + # query schedule and fallback if necessary + workload = autotvm.task.args_to_workload( + [data, kernel, strides, padding, out_dtype], depthwise_conv2d_nchw) \ + if is_depthwise else \ + autotvm.task.args_to_workload( + [data, kernel, strides, padding, layout, out_dtype], conv2d) cfg = dispatch_ctx.query(target, workload) if cfg.is_fallback: - if is_depthwise: - wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype) - from depthwise_conv2d import fallback_schedule - fallback_schedule(cfg, wkl, get_fp32_len()) - else: - wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype) - _get_default_config(cfg, wkl) + _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise) ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] new_attrs['layout'] = 'NCHW%dc' % ic_bn @@ -362,13 +367,11 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn - # get workload and related schedule config - wkl = _get_conv2d_workload(tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), - tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width), - dtype=kernel.dtype), - strides, padding, out_dtype) if cfg.is_fallback: - _get_default_config(cfg, wkl) + _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), + tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width), + 
dtype=kernel.dtype), + strides, padding, out_dtype) # output shape out_height = (ih + 2 * HPAD - kernel_height) // HSTR + 1 @@ -394,7 +397,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, n_elems = 4 assert ic_bn % n_elems == 0 - ic_outer = tvm.reduce_axis((0, wkl.in_filter//ic_bn), name='ic_outer') + ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer') ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner') ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py index 06cf5f729a90..c761210e4d2a 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d.py +++ b/topi/tests/python/test_topi_depthwise_conv2d.py @@ -205,26 +205,25 @@ def get_ref_data(): with autotvm.tophub.context(device): # load tophub pre-tuned parameters check_device(device) +def _transform_data(data, bn): + # NCHW -> NCHW[x]c + batch_size, channel, height, width = data.shape + data = np.transpose(data, (0, 2, 3, 1)) + data = np.reshape(data, (batch_size, height, width, channel//bn, bn)) + data = np.transpose(data, (0, 3, 1, 2, 4)) + return data + +def _transform_kernel(kernel, bn): + # channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block + channel, channel_multiplier, kh, kw = kernel.shape + out_channel = channel * channel_multiplier + kernel = np.transpose(kernel, (2, 3, 0, 1)) + kernel = np.reshape(kernel, (kh, kw, out_channel)) + kernel = np.reshape(kernel, (kh, kw, out_channel//bn, bn)) + kernel = np.transpose(kernel, (2, 0, 1, 3)) + return kernel def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1): - def _transform_data(data, bn): - # NCHW -> NCHW[x]c - batch_size, channel, height, width = data.shape - data = np.transpose(data, (0, 2, 3, 1)) - data = np.reshape(data, (batch_size, height, width, channel//bn, bn)) - data = np.transpose(data, (0, 3, 1, 2, 4)) - return data - - def _transform_kernel(kernel, bn): - # channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block - channel, channel_multiplier, kh, kw = kernel.shape - out_channel = channel * channel_multiplier - kernel = np.transpose(kernel, axes=(2, 3, 0, 1)) - kernel = np.reshape(kernel, shape=(kh, kw, out_channel)) - kernel = np.reshape(kernel, shape=(kh, kw, out_channel//bn, bn)) - kernel = np.transpose(kernel, axes=(2, 0, 1, 3)) - return kernel - in_width = in_height filter_channel = in_channel filter_width = filter_height @@ -240,7 +239,7 @@ def _transform_kernel(kernel, bn): # regardless of the performance. 
oc_block = 1 for bn in range(16, 0, -1): - if num_filter % bn == 0: + if out_channel % bn == 0: oc_block = bn break @@ -251,9 +250,10 @@ def _transform_kernel(kernel, bn): break # placeholder - Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input') - Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter') - + Input = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='Input') + Filter = tvm.placeholder((out_channel//oc_block, filter_height, filter_width, oc_block), name='Filter') + in_layout = "NCHW%dc" % ic_block + out_layout = "NCHW%dc" % oc_block dtype = 'float32' def check_device(device): @@ -266,8 +266,9 @@ def check_device(device): # declare DepthwiseConv2d = topi.nn.depthwise_conv2d_NCHWc(Input, Filter, (stride_h, stride_w), - padding_args, dtype) - # TODO: add scale_shift for NCHWc and add test here + padding_args, in_layout, + out_layout, dtype) + # TODO: add scale_shift implement for NCHWc and add test here Relu = topi.nn.relu(DepthwiseConv2d) # schedule s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d) @@ -277,11 +278,11 @@ def check_device(device): f2 = tvm.build(s2, [Input, Filter, Relu], device) # Prepare pod type for test data closure - input_shape = get_const_tuple(Input.shape) - filter_shape = get_const_tuple(Filter.shape) + input_shape = (batch, in_channel, in_height, in_width) + filter_shape = (filter_channel, channel_multiplier, filter_height, filter_width) # Use memoize, pickle the test data for next time use. - @memoize("topi.tests.test_topi_depthwise_conv2d.nchw") + @memoize("topi.tests.test_topi_depthwise_conv2d.NCHWc") def get_ref_data(): input_np = np.random.uniform(size=input_shape).astype(dtype) filter_np = np.random.uniform(size=filter_shape).astype(dtype) @@ -289,13 +290,18 @@ def get_ref_data(): depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw( input_np, filter_np, stride, padding) relu_scipy = np.maximum(depthwise_conv2d_scipy, 0) - return (input_np, filter_np, depthwise_conv2d_scipy, relu_scipy) + return (_transform_data(input_np, ic_block), + _transform_kernel(filter_np, oc_block), + _transform_data(depthwise_conv2d_scipy, oc_block), + _transform_data(relu_scipy, oc_block)) + # Get the test data (input_np, filter_np, depthwise_conv2d_scipy, relu_scipy) = get_ref_data() input_tvm = tvm.nd.array(input_np, ctx) filter_tvm = tvm.nd.array(filter_np, ctx) - depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx) + depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), + dtype=DepthwiseConv2d.dtype), ctx) relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx) # launch kernel 1 (depthwise_conv2d) timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1) @@ -306,7 +312,8 @@ def get_ref_data(): tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5) tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) - for device in get_all_backend(): + # test llvm only for now since depthwise_conv2d_NCHWc implement is missing in other backend. 
+ for device in ["llvm"]: with autotvm.tophub.context(device): # load tophub pre-tuned parameters check_device(device) @@ -339,5 +346,12 @@ def test_depthwise_conv2d(): # dilation = 2 depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2) + # NCHW[x]c + depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "SAME") + depthwise_conv2d_with_workload_NCHWc(4, 256, 64, 2, 5, 2, "SAME") + depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "VALID") + depthwise_conv2d_with_workload_NCHWc(4, 256, 64, 2, 5, 2, "VALID") + + if __name__ == "__main__": test_depthwise_conv2d() From 40cb824284d057c85dbb6c9a08b2faaed759e3a0 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 31 Oct 2018 14:57:20 -0700 Subject: [PATCH 3/9] lint codes --- topi/python/topi/nn/depthwise_conv2d.py | 2 +- topi/python/topi/x86/conv2d.py | 15 ++++---------- topi/python/topi/x86/conv2d_avx_1x1.py | 4 +++- topi/python/topi/x86/conv2d_avx_common.py | 4 +++- topi/python/topi/x86/depthwise_conv2d.py | 24 ++++++++++++++++++----- topi/python/topi/x86/util.py | 3 +-- 6 files changed, 31 insertions(+), 21 deletions(-) diff --git a/topi/python/topi/nn/depthwise_conv2d.py b/topi/python/topi/nn/depthwise_conv2d.py index 8782c3e30a5f..40ffbb4dedec 100644 --- a/topi/python/topi/nn/depthwise_conv2d.py +++ b/topi/python/topi/nn/depthwise_conv2d.py @@ -1,4 +1,4 @@ -# pylint: disable=invalid-name, unused-variable, too-many-locals +# pylint: disable=invalid-name, unused-variable, too-many-locals, unused-argument """Depthwise convolution operators""" from __future__ import absolute_import as _abs from collections import namedtuple diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 912ae190ea00..ae177c531a6b 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -19,25 +19,18 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False): """ Get default schedule config for the workload - Parameters - ---------- - workload : topi.nn.conv2d.Workload - Convolution workload - is_depthwise : bool - Whether it is depthwise NCHW workload """ - fp32_vec_len = get_fp32_len() if is_depthwise: wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype) - from depthwise_conv2d import fallback_schedule - fallback_schedule(cfg, wkl, fp32_vec_len) + from depthwise_conv2d import _fallback_schedule + _fallback_schedule(cfg, wkl) else: wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype) is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1 if is_kernel_1x1: - conv2d_avx_1x1._fallback_schedule(cfg, wkl, fp32_vec_len) + conv2d_avx_1x1._fallback_schedule(cfg, wkl) else: - conv2d_avx_common._fallback_schedule(cfg, wkl, fp32_vec_len) + conv2d_avx_common._fallback_schedule(cfg, wkl) def _create_tuning_space(cfg, data, kernel, strides, padding, layout): diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index ce70ec83828b..d44e3899293d 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -8,8 +8,10 @@ from ..util import get_const_tuple from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .check_targets import check_skylake +from .util import get_fp32_len -def _fallback_schedule(cfg, wkl, simd_width): +def _fallback_schedule(cfg, wkl): + simd_width = get_fp32_len() HPAD, WPAD = wkl.hpad, wkl.wpad HSTR, WSTR = wkl.hstride, wkl.wstride out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 diff --git 
a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py index e52722ed54a7..1b8ee5fe9be4 100644 --- a/topi/python/topi/x86/conv2d_avx_common.py +++ b/topi/python/topi/x86/conv2d_avx_common.py @@ -8,8 +8,10 @@ from ..util import get_const_tuple from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .check_targets import check_skylake +from .util import get_fp32_len -def _fallback_schedule(cfg, wkl, simd_width): +def _fallback_schedule(cfg, wkl): + simd_width = get_fp32_len() HPAD, WPAD = wkl.hpad, wkl.wpad HSTR, WSTR = wkl.hstride, wkl.wstride out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py index e6bf342c0420..180faba1166b 100644 --- a/topi/python/topi/x86/depthwise_conv2d.py +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -1,3 +1,5 @@ +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Depthwise Conv2D schedule on x86""" import tvm from tvm import autotvm from tvm.autotvm.task.space import SplitEntity @@ -9,7 +11,18 @@ from .util import get_fp32_len -def fallback_schedule(cfg, wkl, simd_width): +def _fallback_schedule(cfg, wkl): + """ + Get default schedule for the workload + Parameters + ---------- + cfg : tvm.autotvm.task.space.FallbackConfigEntity + Fallback config to be updated + wkl : topi.nn.depthwise_conv2d.Workload + Convolution workload + """ + simd_width = get_fp32_len() + HPAD, WPAD = wkl.hpad, wkl.wpad HSTR, WSTR = wkl.hstride, wkl.wstride out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 @@ -43,7 +56,8 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, layout, out_layout, out_dtype=None): out_dtype = data.dtype if out_dtype is None else out_dtype batch, in_channel_chunk, in_height, in_width, in_channel_block = get_const_tuple(data.shape) - out_channel_chunk, filter_height, filter_width, out_channel_block = get_const_tuple(kernel.shape) + out_channel_chunk, filter_height, filter_width, out_channel_block \ + = get_const_tuple(kernel.shape) strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) HSTR, WSTR = strides @@ -62,7 +76,7 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dtype=kernel.dtype), strides, padding, out_dtype) if cfg.is_fallback: - fallback_schedule(cfg, wkl, get_fp32_len()) + _fallback_schedule(cfg, wkl) # padding stage DOPAD = (pad_top != 0 or pad_left != 0 or pad_down != 0 or pad_right != 0) @@ -80,8 +94,8 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, (batch, out_channel_chunk, out_height, out_width, out_channel_block), lambda b, oco, i, j, oci: tvm.sum( (data_pad[b, (oco * out_channel_block + oci) // channel_multiplier // in_channel_block, - i*HSTR+di, j*WSTR+dj, - ((oco * out_channel_block + oci) // channel_multiplier) % in_channel_block] + i*HSTR+di, j*WSTR+dj, + ((oco * out_channel_block + oci) // channel_multiplier) % in_channel_block] .astype(out_dtype) * kernel[oco, di, dj, oci].astype(out_dtype)), axis=[di, dj]), diff --git a/topi/python/topi/x86/util.py b/topi/python/topi/x86/util.py index d7117944cf40..678ff8e24cff 100644 --- a/topi/python/topi/x86/util.py +++ b/topi/python/topi/x86/util.py @@ -1,4 +1,3 @@ -# pylint: disable=invalid-name """Common x86 related utilities""" from __future__ import absolute_import as _abs import tvm @@ -10,4 +9,4 @@ def get_fp32_len(): for opt in target.options: if opt == '-mcpu=skylake-avx512': fp32_vec_len = 16 - return fp32_vec_len \ No 
newline at end of file + return fp32_vec_len From 59bca4d019511d23277acb1c6b7365c5aafd5f33 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 31 Oct 2018 15:04:42 -0700 Subject: [PATCH 4/9] fix lint --- topi/python/topi/x86/conv2d.py | 1 - topi/python/topi/x86/depthwise_conv2d.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index ae177c531a6b..c9e3f7f0f7c0 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -13,7 +13,6 @@ from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw from ..nn.pad import pad -from .util import get_fp32_len from . import conv2d_avx_1x1, conv2d_avx_common def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False): diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py index 180faba1166b..552ad38fe401 100644 --- a/topi/python/topi/x86/depthwise_conv2d.py +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -105,6 +105,7 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, @autotvm.register_topi_schedule(generic.schedule_depthwise_conv2d_NCHWc, 'cpu', ['direct']) def schedule_depthwise_conv2d_NCHWc(cfg, outs): + """CPU schedule for depthwise conv2d in NCHW[x]c layout""" s = tvm.create_schedule([x.op for x in outs]) scheduled_ops = [] def traverse(op): @@ -119,9 +120,8 @@ def traverse(op): if 'depthwise_conv2d_NCHWc' in op.tag: conv_out = op.output(0) data = conv_out.op.input_tensors[0] - input = data kernel = conv_out.op.input_tensors[1] - _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, input, kernel, conv_out, outs[0]) + _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data, kernel, conv_out, outs[0]) scheduled_ops.append(op) traverse(outs[0].op) return s From 207c76ce81bf170f56445819dad315a309d236c3 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 3 Nov 2018 16:05:02 -0700 Subject: [PATCH 5/9] add dilation args --- nnvm/python/nnvm/top/nn.py | 2 +- topi/python/topi/nn/depthwise_conv2d.py | 6 +++++- topi/python/topi/x86/conv2d.py | 3 ++- topi/python/topi/x86/depthwise_conv2d.py | 5 ++++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index 14a1c97f58f3..18748d56931a 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -174,7 +174,7 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): layout, out_layout, out_dtype) elif groups == in_channel and groups == channels: out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, - layout, out_layout, out_dtype) + dilation, layout, out_layout, out_dtype) # pylint: enable=assignment-from-no-return else: raise ValueError("not support arbitrary group number > 1 for now") diff --git a/topi/python/topi/nn/depthwise_conv2d.py b/topi/python/topi/nn/depthwise_conv2d.py index 87fccde50d7a..b5f46b840c9c 100644 --- a/topi/python/topi/nn/depthwise_conv2d.py +++ b/topi/python/topi/nn/depthwise_conv2d.py @@ -283,7 +283,8 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid @tvm.target.generic_func -def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, layout, out_layout, out_dtype=None): +def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation, + layout, out_layout, out_dtype=None): """Depthwise convolution NCHW[x]c forward operator. 
Parameters @@ -302,6 +303,9 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, layout, out_layout, o padding : int or str Padding size, or ['VALID', 'SAME'] + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + layout : str Input data layout diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index d044614f7d2e..325ab841c2ae 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -69,7 +69,8 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout) if cfg.is_fallback: _get_default_config(cfg, data, kernel, strides, padding, out_dtype) - return _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtype) + return _declaration_conv_impl(cfg, data, kernel, strides, + padding, dilation, layout, out_dtype) elif layout == 'HWCN': return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype) elif layout == 'NHWC': diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py index 552ad38fe401..f1846c267e57 100644 --- a/topi/python/topi/x86/depthwise_conv2d.py +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -52,7 +52,7 @@ def _fallback_schedule(cfg, wkl): @autotvm.register_topi_compute(depthwise_conv2d_NCHWc, 'cpu', 'direct') -def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, +def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_layout, out_dtype=None): out_dtype = data.dtype if out_dtype is None else out_dtype batch, in_channel_chunk, in_height, in_width, in_channel_block = get_const_tuple(data.shape) @@ -63,6 +63,9 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, HSTR, WSTR = strides pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (filter_height, filter_width)) + dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + assert (dh, dw) == (1, 1), "Does not support dilation" + in_channel = in_channel_chunk * in_channel_block out_channel = out_channel_chunk * out_channel_block channel_multiplier = out_channel // in_channel From d782ac53bfccc26a8d9519aa49fbf10f098b2f28 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 3 Nov 2018 22:34:27 -0700 Subject: [PATCH 6/9] clean code; enable conv2d_NCHWc tests --- nnvm/python/nnvm/top/nn.py | 8 +++---- topi/python/topi/x86/depthwise_conv2d.py | 15 ++++++------ topi/tests/python/test_topi_conv2d_NCHWc.py | 23 +++++++++---------- .../python/test_topi_depthwise_conv2d.py | 21 ++++++++--------- 4 files changed, 31 insertions(+), 36 deletions(-) diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index 18748d56931a..0802d0e1d1a1 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -159,7 +159,7 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): padding = attrs.get_int_tuple("padding") strides = attrs.get_int_tuple("strides") dilation = attrs.get_int_tuple("dilation") - channels = attrs.get_int("channels") + out_channel = attrs.get_int("channels") groups = attrs.get_int("groups") layout = attrs.get_string("layout") out_layout = attrs.get_string("out_layout") @@ -172,7 +172,7 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): # pylint: disable=assignment-from-no-return out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation, layout, out_layout, out_dtype) - elif groups == in_channel and groups == 
channels: + elif groups == in_channel and groups == out_channel: out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation, layout, out_layout, out_dtype) # pylint: enable=assignment-from-no-return @@ -188,11 +188,11 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): def schedule_contrib_conv2d_NCHWc(attrs, outs, target): """Schedule definition of conv2d NCHWc""" groups = attrs.get_int("groups") - channels = attrs.get_int("channels") + out_channel = attrs.get_int("channels") with tvm.target.create(target): if groups == 1: return topi.generic.schedule_conv2d_NCHWc(outs) - elif groups == channels: + elif groups == out_channel: return topi.generic.schedule_depthwise_conv2d_NCHWc(outs) else: raise ValueError("not support group number > 1 for now") diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py index f1846c267e57..3460755ce112 100644 --- a/topi/python/topi/x86/depthwise_conv2d.py +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -91,17 +91,17 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation, data_pad = data # depthconv stage - di = tvm.reduce_axis((0, filter_height), name='di') - dj = tvm.reduce_axis((0, filter_width), name='dj') + kh = tvm.reduce_axis((0, filter_height), name='kh') + kw = tvm.reduce_axis((0, filter_width), name='kw') Output = tvm.compute( (batch, out_channel_chunk, out_height, out_width, out_channel_block), - lambda b, oco, i, j, oci: tvm.sum( + lambda b, oco, oh, ow, oci: tvm.sum( (data_pad[b, (oco * out_channel_block + oci) // channel_multiplier // in_channel_block, - i*HSTR+di, j*WSTR+dj, + oh*HSTR+kh, ow*WSTR+kw, ((oco * out_channel_block + oci) // channel_multiplier) % in_channel_block] .astype(out_dtype) * - kernel[oco, di, dj, oci].astype(out_dtype)), - axis=[di, dj]), + kernel[oco, kh, kw, oci].astype(out_dtype)), + axis=[kh, kw]), name='DepthwiseConv2d', tag="depthwise_conv2d_NCHWc") return Output @@ -144,10 +144,8 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data, kernel, conv_out, output _, ic_chunk, oh, ow, ic_block = s[C].op.axis ow_chunk, ow_block = s[C].split(ow, factor=tile_ow) s[C].reorder(ic_chunk, oh, ow_chunk, ow_block, ic_block) - s[C].vectorize(ic_block) parallel_axis = s[C].fuse(ic_chunk, oh) s[C].parallel(parallel_axis) - s[C].unroll(ow_block) s[CC].compute_at(s[C], ow_chunk) _, ic_chunk, oh, ow, ic_block = s[CC].op.axis @@ -156,6 +154,7 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data, kernel, conv_out, output s[CC].reorder(ic_chunk, oh, kh, kw, ow_block, ic_block) s[CC].vectorize(ic_block) s[CC].unroll(ow_block) + if C != O: batch, oc_chunk, oh, ow, oc_block = s[O].op.axis ow_chunk, ow_block = s[O].split(ow, factor=tile_ow) diff --git a/topi/tests/python/test_topi_conv2d_NCHWc.py b/topi/tests/python/test_topi_conv2d_NCHWc.py index 38e6ad6d9e7c..a3af43c8d810 100644 --- a/topi/tests/python/test_topi_conv2d_NCHWc.py +++ b/topi/tests/python/test_topi_conv2d_NCHWc.py @@ -13,27 +13,22 @@ def _transform_data(data, bn): # NCHW -> NCHW[x]c batch_size, channel, height, width = data.shape - data = np.transpose(data, (0, 2, 3, 1)) - data = np.reshape(data, (batch_size, height, width, channel//bn, bn)) - data = np.transpose(data, (0, 3, 1, 2, 4)) + data = np.reshape(data, (batch_size, channel//bn, bn, height, width)) + data = np.transpose(data, (0, 1, 3, 4, 2)) return data def _transform_kernel(kernel, ic_bn, oc_bn): # OIHW -> OIHW[x]i[x]o out_channel, in_channel, kh, kw = kernel.shape - kernel = np.transpose(kernel, (1, 2, 3, 0)) - 
kernel = np.reshape(kernel, (in_channel, kh, kw, out_channel//oc_bn, oc_bn)) - kernel = np.transpose(kernel, (1, 2, 3, 4, 0)) - kernel = np.reshape(kernel, (kh, kw, out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn)) - kernel = np.transpose(kernel, (2, 4, 0, 1, 5, 3)) + kernel = np.reshape(kernel, (out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn, kh, kw)) + kernel = np.transpose(kernel, (0, 2, 4, 5, 3, 1)) return kernel def _transform_bias(bias, bn): # [num_filter, 1, 1] -> [num_filter//bn, 1, 1, bn] num_filter, h, w = bias.shape - bias = np.transpose(bias, (1, 2, 0)) - bias = np.reshape(bias, (h, w, num_filter//bn, bn)) - bias = np.transpose(bias, (2, 0, 1, 3)) + bias = np.reshape(bias, (num_filter//bn, bn, h, w)) + bias = np.transpose(bias, (0, 2, 3, 1)) return bias def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride, @@ -86,6 +81,7 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding), + (dilation, dilation), layout='NCHW%dc'%ic_block, out_layout="NCHW%dc"%oc_block, out_dtype=dtype) @@ -117,7 +113,7 @@ def check_device(device): check_device(device) -if __name__ == "__main__": +def test_conv2d_NCHWc(): # ResNet18 workloads verify_conv2d_NCHWc(1, 3, 224, 64, 7, 2, 3) verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1) @@ -204,3 +200,6 @@ def check_device(device): verify_conv2d_NCHWc(1, 2048, 10, 126, 3, 1, 1) verify_conv2d_NCHWc(1, 512, 5, 126, 3, 1, 1) verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1) + +if __name__ == "__main__": + test_conv2d_NCHWc() \ No newline at end of file diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py index 4c2ef67226e3..98c93dff9993 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d.py +++ b/topi/tests/python/test_topi_depthwise_conv2d.py @@ -206,19 +206,16 @@ def get_ref_data(): def _transform_data(data, bn): # NCHW -> NCHW[x]c batch_size, channel, height, width = data.shape - data = np.transpose(data, (0, 2, 3, 1)) - data = np.reshape(data, (batch_size, height, width, channel//bn, bn)) - data = np.transpose(data, (0, 3, 1, 2, 4)) + data = np.reshape(data, (batch_size, channel//bn, bn, height, width)) + data = np.transpose(data, (0, 1, 3, 4, 2)) return data def _transform_kernel(kernel, bn): # channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block channel, channel_multiplier, kh, kw = kernel.shape out_channel = channel * channel_multiplier - kernel = np.transpose(kernel, (2, 3, 0, 1)) - kernel = np.reshape(kernel, (kh, kw, out_channel)) - kernel = np.reshape(kernel, (kh, kw, out_channel//bn, bn)) - kernel = np.transpose(kernel, (2, 0, 1, 3)) + kernel = np.reshape(kernel, (out_channel//bn, bn, kh, kw)) + kernel = np.transpose(kernel, (0, 2, 3, 1)) return kernel def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1): @@ -264,7 +261,9 @@ def check_device(device): # declare DepthwiseConv2d = topi.nn.depthwise_conv2d_NCHWc(Input, Filter, (stride_h, stride_w), - padding_args, in_layout, + padding_args, + (dilation, dilation), + in_layout, out_layout, dtype) # TODO: add scale_shift implement for NCHWc and add test here Relu = topi.nn.relu(DepthwiseConv2d) @@ -302,11 +301,9 @@ def get_ref_data(): dtype=DepthwiseConv2d.dtype), ctx) relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx) # launch kernel 1 (depthwise_conv2d) - 
timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1) - tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean + f1(input_tvm, filter_tvm, depthwise_conv2d_tvm) # launch kernel 2 (depthwise_conv2d + relu) - timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1) - tcost_2 = timer_2(input_tvm, filter_tvm, relu_tvm).mean + f2(input_tvm, filter_tvm, relu_tvm) tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5) tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) From 96763e4a6db546ccdf33dcf67f9683b532168312 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 3 Nov 2018 22:41:13 -0700 Subject: [PATCH 7/9] simplify kernel layout transform --- topi/python/topi/x86/conv2d.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 325ab841c2ae..7e0b90f1db9b 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -327,11 +327,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfo): dtype=data.dtype) if is_depthwise: # channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block + # in which out_channel = merge(channel, channel_multiplier) kernel_sym = copy_inputs[1] - kernel_sym = sym.transpose(kernel_sym, axes=(2, 3, 0, 1)) - kernel_sym = sym.reshape(kernel_sym, shape=(kh, kw, out_channel)) - kernel_sym = sym.reshape(kernel_sym, shape=(kh, kw, out_channel//oc_bn, oc_bn)) - kernel_sym = sym.transpose(kernel_sym, axes=(2, 0, 1, 3)) + kernel_sym = sym.reshape(kernel_sym, shape=(out_channel//oc_bn, oc_bn, kh, kw)) + kernel_sym = sym.transpose(kernel_sym, axes=(0, 2, 3, 1)) copy_inputs[1] = kernel_sym # Store altered operator's config From 5aef217cd46403257924d0c2fae1f698720e3e6e Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 11 Nov 2018 15:59:53 -0800 Subject: [PATCH 8/9] fix merging upstream --- topi/python/topi/generic/nn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 47fe39df002a..8c303e5be182 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -191,6 +191,7 @@ def schedule_depthwise_conv2d_NCHWc(outs): return _default_schedule(outs, False) +@tvm.target.generic_func def schedule_group_conv2d_nchw(outs): """Schedule for conv2d_nchw From 8fef6aec6b7590397e544d9499cf0b0e635e55f5 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 11 Nov 2018 22:30:34 -0800 Subject: [PATCH 9/9] add depthwise conv2d NCHWc tuning support for x86, update tutorial --- topi/python/topi/x86/depthwise_conv2d.py | 37 ++++++++++++++++++++++++ tutorials/autotvm/tune_nnvm_x86.py | 10 ++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py index 3460755ce112..8f37a0316229 100644 --- a/topi/python/topi/x86/depthwise_conv2d.py +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -2,7 +2,9 @@ """Depthwise Conv2D schedule on x86""" import tvm from tvm import autotvm +from tvm.autotvm.task import get_config from tvm.autotvm.task.space import SplitEntity +from tvm.autotvm.task.nnvm_integration import deserialize_args from .. 
import generic, tag from ..nn.pad import pad from ..util import get_const_tuple @@ -164,3 +166,38 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data, kernel, conv_out, output s[O].vectorize(oc_block) s[O].parallel(parallel_axis) return s + + +@autotvm.task.register("topi_x86_depthwise_conv2d_NCHWc_from_nchw") +def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + data, kernel, strides, padding, dilation, dtype = deserialize_args(args) + + batch, in_channel, height, width = get_const_tuple(data.shape) + filter_channel, channel_multiplier, kh, kw = get_const_tuple(kernel.shape) + ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding) + sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) + out_height = (height - kh + 2 * ph) // sh + 1 + out_width = (width - kw + 2 * pw) // sw + 1 + out_channel = filter_channel * channel_multiplier + + # get config here + cfg = get_config() + cfg.define_split("tile_ic", in_channel, num_outputs=2) + cfg.define_split("tile_oc", out_channel, num_outputs=2) + cfg.define_split("tile_ow", out_width, num_outputs=2, filter=lambda y: y.size[-1] <= 64) + + # change shape with the value in config + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + new_data_shape = (batch, in_channel // ic_bn, height, width, ic_bn) + new_kernel_shape = (out_channel // oc_bn, kh, kw, oc_bn) + new_data = tvm.placeholder(new_data_shape, data.dtype) + new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype) + + data_layout = "NCHW%dc" % ic_bn + out_layout = "NCHW%dc" % oc_bn + + C = _depthwise_conv2d_NCHWc_cpu(cfg, new_data, new_kernel, strides, padding, dilation, + data_layout, out_layout, dtype) + s = schedule_depthwise_conv2d_NCHWc(cfg, [C]) + return s, [new_data, new_kernel, C] diff --git a/tutorials/autotvm/tune_nnvm_x86.py b/tutorials/autotvm/tune_nnvm_x86.py index 18f1117dc68a..9f8692c3981e 100644 --- a/tutorials/autotvm/tune_nnvm_x86.py +++ b/tutorials/autotvm/tune_nnvm_x86.py @@ -117,7 +117,15 @@ def tune_kernels(tasks, prefix = "[Task %2d/%2d] " % (i+1, len(tasks)) # converting conv2d tasks to conv2d_NCHWc tasks - task = autotvm.task.create("topi_x86_conv2d_NCHWc", args=tsk.args, + op_name = tsk.workload[0] + if op_name == 'conv2d': + func_create = 'topi_x86_conv2d_NCHWc' + elif op_name == 'depthwise_conv2d_nchw': + func_create = 'topi_x86_depthwise_conv2d_NCHWc_from_nchw' + else: + raise ValueError("Tuning {} is not supported on x86".format(op_name)) + + task = autotvm.task.create(func_create, args=tsk.args, target=target, template_key='direct') task.workload = tsk.workload
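
A quick sanity check, separate from the patches themselves: the NCHW -> NCHW[x]c layout transforms above appear in two forms, first as transpose/reshape chains (patches 1-2) and later simplified to a single reshape followed by a transpose (patches 6-7). The minimal NumPy sketch below uses arbitrary example sizes (32 input channels, channel_multiplier 2, a 3x3 kernel, and block size 8, the default fp32 vector length returned by get_fp32_len) to verify that both forms yield identical blocked data and kernel tensors:

import numpy as np

batch, channel, height, width = 1, 32, 16, 16
channel_multiplier, kh, kw = 2, 3, 3
ic_bn = oc_bn = 8                      # example block size only
out_channel = channel * channel_multiplier

data = np.random.uniform(size=(batch, channel, height, width)).astype("float32")
kernel = np.random.uniform(size=(channel, channel_multiplier, kh, kw)).astype("float32")

# data: NCHW -> NCHW[x]c, transpose-based form (early test helper)
d1 = np.transpose(data, (0, 2, 3, 1))
d1 = np.reshape(d1, (batch, height, width, channel // ic_bn, ic_bn))
d1 = np.transpose(d1, (0, 3, 1, 2, 4))
# data: NCHW -> NCHW[x]c, simplified form (patch 6 test helper)
d2 = np.reshape(data, (batch, channel // ic_bn, ic_bn, height, width))
d2 = np.transpose(d2, (0, 1, 3, 4, 2))
np.testing.assert_allclose(d1, d2)

# kernel: (channel, channel_multiplier, kh, kw) ->
#         (out_channel_chunk, kh, kw, out_channel_block)
# transpose-based chain (patches 1-2)
k1 = np.transpose(kernel, (2, 3, 0, 1))
k1 = np.reshape(k1, (kh, kw, out_channel))
k1 = np.reshape(k1, (kh, kw, out_channel // oc_bn, oc_bn))
k1 = np.transpose(k1, (2, 0, 1, 3))
# simplified form (patches 6-7)
k2 = np.reshape(kernel, (out_channel // oc_bn, oc_bn, kh, kw))
k2 = np.transpose(k2, (0, 2, 3, 1))
np.testing.assert_allclose(k1, k2)

print(d1.shape, k1.shape)  # (1, 4, 16, 16, 8) (8, 3, 3, 8)

The same equivalence is what lets _alter_conv2d_layout in patch 7 drop two of the four symbol-level reshape/transpose steps when pre-transforming the depthwise kernel.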