7 changes: 5 additions & 2 deletions python/tvm/autotvm/task/topi_integration.py
@@ -179,12 +179,15 @@ def _topi_nn_conv2d(*args, **kwargs):
args = deserialize_args(args)
A, W = args[:2]
layout = args[-2]
-assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently"
C = topi.nn.conv2d(*args, **kwargs)
if layout == 'NCHW':
    s = topi.generic.schedule_conv2d_nchw([C])
-else:
+elif layout == 'HWCN':
    s = topi.generic.schedule_conv2d_hwcn([C])
+elif layout == 'NHWC':
+    s = topi.generic.schedule_conv2d_nhwc([C])
+else:
+    raise ValueError("Unsupported layout {}".format(layout))
return s, [A, W, C]

@register("topi_nn_depthwise_conv2d_nchw")
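
With the assert removed, the wrapper above now routes NHWC workloads to the generic NHWC schedule instead of rejecting them. For reference, a minimal standalone sketch of the same dispatch outside autotvm (tensor names and shapes are illustrative, not taken from this PR):

    import tvm
    import topi

    # NHWC input and HWIO kernel, illustrative shapes
    data = tvm.placeholder((1, 56, 56, 64), name='data')
    kernel = tvm.placeholder((3, 3, 64, 64), name='kernel')

    layout = 'NHWC'
    C = topi.nn.conv2d(data, kernel, strides=1, padding=1, dilation=1,
                       layout=layout, out_dtype='float32')
    if layout == 'NCHW':
        s = topi.generic.schedule_conv2d_nchw([C])
    elif layout == 'HWCN':
        s = topi.generic.schedule_conv2d_hwcn([C])
    elif layout == 'NHWC':
        s = topi.generic.schedule_conv2d_nhwc([C])
    else:
        raise ValueError("Unsupported layout {}".format(layout))

Without a target context this falls back to the default NHWC schedule; the arm_cpu template registered in the next file is what autotvm actually tunes.
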
38 changes: 36 additions & 2 deletions topi/python/topi/arm_cpu/conv2d.py
@@ -24,7 +24,8 @@
from tvm import autotvm
import tvm.contrib.nnpack

-from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
+from ..generic import schedule_conv2d_nchw, schedule_conv2d_nhwc, \
+    schedule_conv2d_winograd_without_weight_transform, \
    schedule_conv2d_winograd_nnpack_without_weight_transform
from ..util import traverse_inline, get_const_tuple
from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
@@ -34,7 +35,9 @@
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices
from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
-    schedule_conv2d_spatial_pack_nchw
+    conv2d_spatial_pack_nhwc, \
+    schedule_conv2d_spatial_pack_nchw, \
+    schedule_conv2d_spatial_pack_nhwc

logger = logging.getLogger('topi')

@@ -78,6 +81,9 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
if layout == 'NCHW':
return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
dilation, out_dtype, num_tile=2)
elif layout == 'NHWC':
return conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding,
dilation, out_dtype)
else:
raise ValueError("Unsupported layout {}".format(layout))

@@ -136,6 +142,34 @@ def _callback(op):
traverse_inline(s, outs[0].op, _callback)
return s

@autotvm.register_topi_schedule(schedule_conv2d_nhwc, 'arm_cpu', ['direct'])
def schedule_conv2d_nhwc_arm_cpu(cfg, outs):
"""TOPI schedule callback for conv2d

Parameters
----------
cfg: ConfigEntity
The config for this template

outs: Array of Tensor
The computation graph description of conv2d
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for conv2d.
"""
s = tvm.create_schedule([x.op for x in outs])

def _callback(op):
if 'spatial_conv_output_NHWC' in op.tag:
schedule_conv2d_spatial_pack_nhwc(cfg, s, op, outs[0])

traverse_inline(s, outs[0].op, _callback)
return s
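
Usage sketch (not part of this PR): with the compute and this schedule registered for 'arm_cpu', a plain topi call under that target dispatches to them; without tuned logs autotvm falls back to a default config and prints a fallback warning. Shapes are illustrative.

    import tvm
    import topi

    target = tvm.target.create('llvm -device=arm_cpu')
    data = tvm.placeholder((1, 56, 56, 32), name='data')      # NHWC
    kernel = tvm.placeholder((3, 3, 32, 32), name='kernel')   # HWIO

    with target:
        out = topi.nn.conv2d(data, kernel, strides=1, padding=1, dilation=1,
                             layout='NHWC', out_dtype='float32')
        s = topi.generic.schedule_conv2d_nhwc([out])
    print(tvm.lower(s, [data, kernel, out], simple_mode=True))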


@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
""" TOPI compute callback. Use winograd template """
160 changes: 160 additions & 0 deletions topi/python/topi/arm_cpu/conv2d_spatial_pack.py
@@ -196,3 +196,163 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
s[kernel_vec].parallel(co)

return s

def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""Spatial pack compute for Conv2d NHWC"""
out_dtype = out_dtype or data.dtype

N, IH, IW, IC = get_const_tuple(data.shape)
assert len(kernel.shape) == 4, "AlterOpLayout not enabled for NHWC yet"
KH, KW, _, OC = get_const_tuple(kernel.shape)

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation

dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1

pad_top, pad_left, pad_down, pad_right = \
get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)

OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0])

# ==================== define configuration space ====================
n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW)

oco, oci = cfg.define_split('tile_co', oc, num_outputs=2)
oho, ohi = cfg.define_split('tile_oh', oh, num_outputs=2)
owo, owi = cfg.define_split('tile_ow', ow, num_outputs=2)

cfg.define_reorder('reorder_conv',
[n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
policy='candidate', candidate=[
[n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
[n, oho, owo, oco, ohi, kh, kw, ic, owi, oci],
[n, oho, owo, oco, ohi, kh, kw, owi, ic, oci],
[n, oho, owo, ohi, oco, kh, kw, owi, ic, oci]])

cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
cfg.define_annotate("ann_spatial", [ohi, owi, oci], policy='try_unroll_vec')
# ====================================================================

OCI = cfg['tile_co'].size[-1]
OHI = cfg['tile_oh'].size[-1]
OWI = cfg['tile_ow'].size[-1]
OCO = OC // OCI
OHO = OH // OHI
OWO = OW // OWI

kvshape = (OCO, KH, KW, IC, OCI)
ovshape = (N, OHO, OWO, OCO, OHI, OWI, OCI)
oshape = (N, OH, OW, OC)

if dilation_h != 1 or dilation_w != 1:
# undilate input data
dvshape = (N, OHO, OWO, KH, KW, IC, OHI, OWI)
data_vec = tvm.compute(dvshape, lambda n, oho, owo, kh, kw, ic, ohi, owi:
data_pad[n][(oho*OHI+ohi)*HSTR+kh*dilation_h]
[(owo*OWI+owi)*WSTR+kw*dilation_w][ic],
name='data_vec_undilated')
else:
dvshape = (N, OHO, OWO, KH + (OHI-1)*HSTR, KW + (OWI-1)*WSTR, IC)
data_vec = tvm.compute(dvshape, lambda n, oho, owo, ohi, owi, ic:
data_pad[n][oho*OHI*HSTR+ohi][owo*OWI*WSTR+owi][ic],
name='data_vec')
kernel_vec = tvm.compute(kvshape, lambda oco, kh, kw, ic, oci: \
kernel[kh][kw][ic][oco*OCI+oci],
name='kernel_vec')

ic = tvm.reduce_axis((0, IC), name='ic')
kh = tvm.reduce_axis((0, KH), name='kh')
kw = tvm.reduce_axis((0, KW), name='kw')

if dilation_h != 1 or dilation_w != 1:
conv = tvm.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
tvm.sum(data_vec[n, oho, owo, kh, kw, ic, ohi, owi].astype(out_dtype) *
kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
axis=[ic, kh, kw]), name='conv')
else:
conv = tvm.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
tvm.sum(data_vec[n, oho, owo, ohi*HSTR+kh, owi*WSTR+kw, ic].astype(out_dtype) *
kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
axis=[ic, kh, kw]), name='conv')

idiv = tvm.indexdiv
imod = tvm.indexmod
output = tvm.compute(oshape, lambda n, oho, owo, oc:
conv[n][idiv(oho, OHI)][idiv(owo, OWI)][idiv(oc, OCI)]\
[imod(oho, OHI)][imod(owo, OWI)][imod(oc, OCI)],
name='output_unpack', tag='spatial_conv_output_NHWC')
return output
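
To make the tiling arithmetic concrete, a NumPy-only sketch (no TVM; illustrative shapes and tile sizes chosen to divide evenly) of the same spatial-pack scheme for stride 1, no padding, no dilation: accumulate into an (N, OHO, OWO, OCO, OHI, OWI, OCI) buffer, then unpack with the same floor-div/mod mapping as 'output_unpack'.

    import numpy as np

    N, IH, IW, IC = 1, 8, 8, 4
    KH, KW, OC = 3, 3, 8
    HSTR = WSTR = 1
    OH, OW = IH - KH + 1, IW - KW + 1          # 6 x 6 output
    OHI, OWI, OCI = 2, 3, 4                    # inner tile sizes
    OHO, OWO, OCO = OH // OHI, OW // OWI, OC // OCI

    data = np.random.rand(N, IH, IW, IC).astype('float32')     # NHWC
    kernel = np.random.rand(KH, KW, IC, OC).astype('float32')  # HWIO

    # reference direct NHWC convolution
    ref = np.zeros((N, OH, OW, OC), dtype='float32')
    for oh in range(OH):
        for ow in range(OW):
            patch = data[:, oh:oh + KH, ow:ow + KW, :]
            ref[:, oh, ow, :] = np.tensordot(patch, kernel, axes=([1, 2, 3], [0, 1, 2]))

    # pack kernel to (OCO, KH, KW, IC, OCI), accumulate per output tile
    kernel_vec = kernel.reshape(KH, KW, IC, OCO, OCI).transpose(3, 0, 1, 2, 4)
    conv = np.zeros((N, OHO, OWO, OCO, OHI, OWI, OCI), dtype='float32')
    for oho in range(OHO):
        for owo in range(OWO):
            for ohi in range(OHI):
                for owi in range(OWI):
                    oh, ow = oho * OHI + ohi, owo * OWI + owi
                    patch = data[:, oh * HSTR:oh * HSTR + KH, ow * WSTR:ow * WSTR + KW, :]
                    conv[:, oho, owo, :, ohi, owi, :] = np.einsum(
                        'nhwc,ohwcp->nop', patch, kernel_vec)

    # unpack with the same div/mod index mapping as 'output_unpack'
    out = np.empty((N, OH, OW, OC), dtype='float32')
    for oh in range(OH):
        for ow in range(OW):
            for oc in range(OC):
                out[:, oh, ow, oc] = conv[:, oh // OHI, ow // OWI, oc // OCI,
                                          oh % OHI, ow % OWI, oc % OCI]

    np.testing.assert_allclose(out, ref, rtol=1e-4)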

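For reference, a sketch of what the "define configuration space" block above builds: the same split knobs on a bare autotvm ConfigSpace (axis lengths are illustrative). The reorder and annotation knobs defined in the compute, plus the 'compat' knob defined in the schedule below, enlarge this space further.

    from tvm import autotvm

    cfg = autotvm.task.space.ConfigSpace()
    oc, oh, ow = cfg.axis(32), cfg.axis(56), cfg.axis(56)   # OC, OH, OW
    cfg.define_split('tile_co', oc, num_outputs=2)
    cfg.define_split('tile_oh', oh, num_outputs=2)
    cfg.define_split('tile_ow', ow, num_outputs=2)
    print(len(cfg))   # number of candidate configs from the three splits alone
    print(cfg)        # lists each knob and its candidate entities
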
def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output):
"""Spatial Pack schedule for Conv2d NHWC"""
unpack = op.output(0)
conv = unpack.op.input_tensors[0]
data_vec = conv.op.input_tensors[0]
kernel_vec = conv.op.input_tensors[1]
data_pad = data_vec.op.input_tensors[0]
OHI = cfg['tile_oh'].size[-1]
OWI = cfg['tile_ow'].size[-1]
OCI = cfg['tile_co'].size[-1]

# schedule unpack/output
if output != unpack:
s[unpack].compute_inline()
n, oh, ow, oc = s[output].op.axis
oco, oci = cfg['tile_co'].apply(s, output, oc)
oho, ohi = cfg['tile_oh'].apply(s, output, oh)
owo, owi = cfg['tile_ow'].apply(s, output, ow)
s[output].reorder(n, oho, owo, oco, ohi, owi, oci)
cfg['ann_spatial'].apply(s, output, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
max_unroll=16, cfg=cfg)
cfg.define_knob('compat', [0, 1, 2])
if cfg['compat'].val < 2:
compat_axis = [owo, oco][cfg['compat'].val] # pylint: disable=R1706
s[conv].compute_at(s[output], compat_axis)
paxis = s[output].fuse(n, oho)
s[output].parallel(paxis)

# schedule conv
n, oho, owo, oco, ohi, owi, oci = s[conv].op.axis
ic, kh, kw = s[conv].op.reduce_axis
cfg['reorder_conv'].apply(s, conv, [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci])
cfg['ann_reduce'].apply(s, conv, [kh, kw],
axis_lens=[get_const_int(kh.dom.extent),
get_const_int(kw.dom.extent)],
max_unroll=16,
cfg=cfg)
cfg['ann_spatial'].apply(s, conv, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
max_unroll=16, cfg=cfg)
if cfg['compat'].val < 2:
compat_axis = [owo, oco][cfg['compat'].val] # pylint: disable=R1706
s[kernel_vec].compute_at(s[conv], compat_axis)
s[data_vec].compute_at(s[conv], compat_axis)

# schedule kernel pack
oco, kh, kw, ic, oci = kernel_vec.op.axis
s[kernel_vec].vectorize(oci)
s[kernel_vec].unroll(ic)
if cfg['compat'].val == 2:
s[kernel_vec].parallel(oco)

# schedule data pack
if data_vec.op.name == 'data_vec_undilated':
n, oho, owo, kh, kw, ic, ohi, owi = s[data_vec].op.axis
s[data_vec].vectorize(owi)
s[data_vec].unroll(ohi)
else:
n, oho, owo, ohi, owi, ic = s[data_vec].op.axis
s[data_vec].vectorize(ic)
s[data_vec].unroll(owi)
if cfg['compat'].val == 2:
paxis = s[data_vec].fuse(n, oho)
s[data_vec].parallel(paxis)

return s
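
Finally, an end-to-end sketch (illustrative shapes, not code from this PR): build the NHWC spatial-pack path under the arm_cpu target with autotvm's fallback config and check it against the NumPy reference in topi.testing.

    import numpy as np
    import tvm
    import topi
    import topi.testing

    data = tvm.placeholder((1, 56, 56, 32), name='data')
    kernel = tvm.placeholder((3, 3, 32, 32), name='kernel')
    target = tvm.target.create('llvm -device=arm_cpu')

    with target:
        out = topi.nn.conv2d(data, kernel, strides=1, padding=1, dilation=1,
                             layout='NHWC', out_dtype='float32')
        s = topi.generic.schedule_conv2d_nhwc([out])
    func = tvm.build(s, [data, kernel, out], target)

    a_np = np.random.uniform(size=(1, 56, 56, 32)).astype('float32')
    w_np = np.random.uniform(size=(3, 3, 32, 32)).astype('float32')
    ref = topi.testing.conv2d_nhwc_python(a_np, w_np, 1, 1)    # stride 1, pad 1

    ctx = tvm.cpu(0)
    c = tvm.nd.array(np.zeros(ref.shape, dtype='float32'), ctx)
    func(tvm.nd.array(a_np, ctx), tvm.nd.array(w_np, ctx), c)
    np.testing.assert_allclose(c.asnumpy(), ref, rtol=1e-4)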