Merged
11 changes: 10 additions & 1 deletion nnvm/python/nnvm/top/nn.py
@@ -4,7 +4,7 @@

import tvm
import topi
from topi.util import get_const_int
from topi.util import get_const_int, get_const_tuple
from .tensor import _fschedule_broadcast, _fschedule_injective
from . import registry as reg
from .registry import OpPattern
@@ -164,16 +164,22 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
out_channel = attrs.get_int("channels")
groups = attrs.get_int("groups")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
out_dtype = attrs.get_string("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
_, in_channel_chunk, _, _, in_channel_block = get_const_tuple(inputs[0].shape)
in_channel = in_channel_chunk * in_channel_block
assert dilation == (1, 1), "not support dilate now"
if groups == 1:
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
layout, out_layout, out_dtype)
elif groups == in_channel and groups == out_channel:
out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding,
dilation, layout, out_layout, out_dtype)
# pylint: enable=assignment-from-no-return
else:
raise ValueError("not support arbitrary group number > 1 for now")
@@ -187,9 +193,12 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of conv2d NCHWc"""
groups = attrs.get_int("groups")
out_channel = attrs.get_int("channels")
with tvm.target.create(target):
if groups == 1:
return topi.generic.schedule_conv2d_NCHWc(outs)
elif groups == out_channel:
return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)
else:
raise ValueError("not support group number > 1 for now")

4 changes: 3 additions & 1 deletion nnvm/src/top/nn/convolution.cc
@@ -82,7 +82,9 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,

wshape[kernel_layout.indexof('O')] *= param.groups;

NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
if (in_shape->at(Conv2DParam::kWeight).ndim() == 0) {
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
}
if (param.use_bias) {
static const Layout default_bias_layout("C");
TShape bias_shape({param.channels});
17 changes: 17 additions & 0 deletions topi/python/topi/generic/nn.py
@@ -174,6 +174,23 @@ def schedule_depthwise_conv2d_nhwc(outs):
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_depthwise_conv2d_NCHWc(outs):
"""Schedule for depthwise_conv2d_NCHWc
Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise_conv2d_nhwc
in the format of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
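Like the other generic schedules in this file, the new hook only returns a default schedule; a backend is expected to override it per target. The PR's actual x86 registration lives in topi/python/topi/x86/depthwise_conv2d.py (imported in the x86 __init__.py below) and is not shown here; a minimal sketch of such an override using the generic_func register decorator, purely for illustration:

    import tvm
    from topi import generic

    @generic.schedule_depthwise_conv2d_NCHWc.register(["cpu"])
    def _schedule_depthwise_conv2d_NCHWc_cpu(outs):
        # Trivial placeholder schedule; a real backend would tile and vectorize here.
        outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
        return tvm.create_schedule([x.op for x in outs])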


@tvm.target.generic_func
def schedule_group_conv2d_nchw(outs):
"""Schedule for conv2d_nchw
65 changes: 64 additions & 1 deletion topi/python/topi/nn/depthwise_conv2d.py
@@ -1,13 +1,35 @@
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=invalid-name, unused-variable, too-many-locals, unused-argument
"""Depthwise convolution operators"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm

from .dilate import dilate
from .pad import pad
from .util import get_pad_tuple
from ..util import simplify

# workload description of depthwise-conv2d
Workload = namedtuple('Workload',
['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

Review thread on the Workload definition:
Contributor: Do we want dilation in here? Or is that for a separate PR?
Member (Author): Workload here is for getting the default schedule; since dilation so far does not impact how we calculate configs, I'd rather keep it simple for now.
Contributor: Makes sense. This is resolved from my side.

def _get_workload(data, kernel, stride, padding, out_dtype):
""" Get the workload structure. """
_, in_channel, height, width = [x.value for x in data.shape]
channel, channel_multiplier, kh, kw = [x.value for x in kernel.shape]
out_channel = channel * channel_multiplier
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
"Do not support inputs with different data types now. " \
"{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, height, width, in_channel,
out_channel, kh, kw, HPAD, WPAD, HSTR, WSTR)
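For reference, a hypothetical 32-channel, channel_multiplier-1, 3x3, stride-1, pad-1 depthwise convolution yields the following workload (shapes chosen purely for illustration):

    import tvm
    from topi.nn.depthwise_conv2d import _get_workload

    data = tvm.placeholder((1, 32, 56, 56), name="data")
    kernel = tvm.placeholder((32, 1, 3, 3), name="kernel")
    wkl = _get_workload(data, kernel, stride=1, padding=1, out_dtype="float32")
    # Workload(in_dtype='float32', out_dtype='float32', height=56, width=56,
    #          in_filter=32, out_filter=32, hkernel=3, wkernel=3,
    #          hpad=1, wpad=1, hstride=1, wstride=1)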


@tvm.target.generic_func
def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
@@ -258,3 +280,44 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid
tag='depthwise_conv2d_backward_weight_nhwc')

return Weight_grad


@tvm.target.generic_func
def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
layout, out_layout, out_dtype=None):
"""Depthwise convolution NCHW[x]c forward operator.

Parameters
----------
Input : tvm.Tensor
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

Filter : tvm.Tensor
4-D with shape [out_channel_chunk, filter_height, filter_width, out_channel_block]
In NCHWc depthwise convolution,
we group kernel's in_channel and channel_multiplier together then do the tiling.

stride : tuple of two ints
The spatial stride along height and width

padding : int or str
Padding size, or ['VALID', 'SAME']

dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]

layout : str
Input data layout

out_layout : str
Output data layout

out_dtype: str, optional
Output data type

Returns
-------
Output : tvm.Tensor
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc")
1 change: 1 addition & 0 deletions topi/python/topi/x86/__init__.py
@@ -9,3 +9,4 @@
from .injective import *
from .pooling import schedule_pool, schedule_global_pool
from .bitserial_conv2d import schedule_bitserial_conv2d
from .depthwise_conv2d import schedule_depthwise_conv2d_NCHWc
114 changes: 65 additions & 49 deletions topi/python/topi/x86/conv2d.py
@@ -7,36 +7,30 @@
from .. import generic, tag
from .. import nn
from ..util import get_const_tuple
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, _get_workload
from ..nn.conv2d import conv2d, conv2d_NCHWc, \
conv2d_alter_layout, _get_workload as _get_conv2d_workload
from ..nn.dilate import dilate
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
from ..nn.pad import pad

from . import conv2d_avx_1x1, conv2d_avx_common

def _get_fp32_len():
fp32_vec_len = 8
target = tvm.target.current_target()
if target is not None:
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
fp32_vec_len = 16
return fp32_vec_len
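The SIMD-width helper is dropped from this file; the AVX schedule files changed below now import get_fp32_len from .util, a module that is not shown in this diff. Presumably it mirrors the code deleted here, roughly:

    # topi/python/topi/x86/util.py (presumed contents, mirroring the deleted helper)
    import tvm

    def get_fp32_len():
        fp32_vec_len = 8
        target = tvm.target.current_target()
        if target is not None:
            for opt in target.options:
                if opt == '-mcpu=skylake-avx512':
                    fp32_vec_len = 16
        return fp32_vec_len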


def _get_default_config(cfg, workload):
def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False):
"""
Get default schedule config for the workload
Parameters
----------
workload : topi.nn.conv2d.Workload
Convolution workload
"""
fp32_vec_len = _get_fp32_len()
is_kernel_1x1 = workload.hkernel == 1 and workload.wkernel == 1
if is_kernel_1x1:
conv2d_avx_1x1._fallback_schedule(cfg, workload, fp32_vec_len)
if is_depthwise:
wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
from .depthwise_conv2d import _fallback_schedule
_fallback_schedule(cfg, wkl)
else:
conv2d_avx_common._fallback_schedule(cfg, workload, fp32_vec_len)
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
if is_kernel_1x1:
conv2d_avx_1x1._fallback_schedule(cfg, wkl)
else:
conv2d_avx_common._fallback_schedule(cfg, wkl)


def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
@@ -74,10 +68,9 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out
if layout == 'NCHW':
_create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
if cfg.is_fallback:
wkl = _get_workload(data, kernel, strides, padding, out_dtype)
_get_default_config(cfg, wkl)
return _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout,
out_dtype)
_get_default_config(cfg, data, kernel, strides, padding, out_dtype)
return _declaration_conv_impl(cfg, data, kernel, strides,
padding, dilation, layout, out_dtype)
elif layout == 'HWCN':
return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
elif layout == 'NHWC':
@@ -295,44 +288,69 @@ def _alter_conv2d_layout(attrs, inputs, tinfo):
copy_inputs = [s for s in inputs]
new_attrs = {k : attrs[k] for k in attrs.keys()}
data, kernel = tinfo[0], tinfo[1]
# only optimize for NCHW, groups=1 conv
if attrs['layout'] != 'NCHW' or attrs.get_int("groups") != 1:
return None
batch_size, in_channel, height, width = get_const_tuple(data.shape)
out_channel, _, kh, kw = get_const_tuple(kernel.shape)

groups = attrs.get_int("groups")
out_channel = attrs.get_int("channels")
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
layout = attrs['layout']
kh, kw = attrs.get_int_tuple("kernel_size")

dtype = data.dtype
out_dtype = dtype if attrs["out_dtype"] == "same" else attrs["out_dtype"]
is_depthwise = groups == in_channel and groups == out_channel

# only optimize for NCHW
if layout != 'NCHW':
return None
if groups != 1 and not is_depthwise:
return None

workload = autotvm.task.args_to_workload(
[data, kernel, strides, padding, layout, out_dtype], conv2d)
dispatch_ctx = autotvm.task.DispatchContext.current
target = tvm.target.current_target()
# query schedule and fallback if necessary
workload = autotvm.task.args_to_workload(
[data, kernel, strides, padding, out_dtype], depthwise_conv2d_nchw) \
if is_depthwise else \
autotvm.task.args_to_workload(
[data, kernel, strides, padding, layout, out_dtype], conv2d)
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback:
wkl = _get_workload(data, kernel, strides, padding, out_dtype)
_get_default_config(cfg, wkl)
_get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)

ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
new_attrs['layout'] = 'NCHW%dc' % ic_bn
new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)

# Store the same config for the altered operator (workload)
new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
dtype=data.dtype)
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, new_attrs['layout'],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
if is_depthwise:
# channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block
# in which out_channel = merge(channel, channel_multiplier)
kernel_sym = copy_inputs[1]
kernel_sym = sym.reshape(kernel_sym, shape=(out_channel//oc_bn, oc_bn, kh, kw))
kernel_sym = sym.transpose(kernel_sym, axes=(0, 2, 3, 1))
copy_inputs[1] = kernel_sym

# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, new_attrs['layout'],
new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
else:
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)

# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, new_attrs['layout'],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)

dispatch_ctx.update(target, new_workload, cfg)
return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
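For the depthwise branch, the symbolic reshape/transpose above turns the framework's (channel, channel_multiplier, kh, kw) kernel into the blocked (out_channel_chunk, kh, kw, out_channel_block) layout that depthwise_conv2d_NCHWc expects. The same transform in NumPy, for a hypothetical 32-channel, multiplier-1, 3x3 kernel with oc_bn = 8:

    import numpy as np

    channel, channel_multiplier, kh, kw = 32, 1, 3, 3
    oc_bn = 8
    out_channel = channel * channel_multiplier

    k = np.random.randn(channel, channel_multiplier, kh, kw).astype("float32")
    # channel, channel_multiplier, kh, kw -> out_channel_chunk, oc_bn, kh, kw
    k = k.reshape(out_channel // oc_bn, oc_bn, kh, kw)
    # -> out_channel_chunk, kh, kw, out_channel_block
    k = k.transpose(0, 2, 3, 1)
    assert k.shape == (out_channel // oc_bn, kh, kw, oc_bn)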


Expand All @@ -354,13 +372,11 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn

# get workload and related schedule config
wkl = _get_workload(tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
dtype=kernel.dtype),
strides, padding, out_dtype)
if cfg.is_fallback:
_get_default_config(cfg, wkl)
_get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
dtype=kernel.dtype),
strides, padding, out_dtype)

# output shape
out_height = (ih + 2 * HPAD - kernel_height) // HSTR + 1
Expand All @@ -386,7 +402,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
n_elems = 4
assert ic_bn % n_elems == 0

ic_outer = tvm.reduce_axis((0, wkl.in_filter//ic_bn), name='ic_outer')
ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
4 changes: 3 additions & 1 deletion topi/python/topi/x86/conv2d_avx_1x1.py
@@ -8,8 +8,10 @@
from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
from .util import get_fp32_len

def _fallback_schedule(cfg, wkl, simd_width):
def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
4 changes: 3 additions & 1 deletion topi/python/topi/x86/conv2d_avx_common.py
@@ -8,8 +8,10 @@
from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
from .util import get_fp32_len

def _fallback_schedule(cfg, wkl, simd_width):
def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1