Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions nnvm/src/top/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,13 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(param.channels % param.groups, 0U)
<< "output channels must divide group size";

TShape wshape({param.channels / param.groups,
TShape wshape({param.channels,
dshape[1] / param.groups,
param.kernel_size[0],
param.kernel_size[1]});

wshape = ConvertLayout(wshape, kOIHW, kernel_layout);

wshape[kernel_layout.indexof('O')] *= param.groups;

if (in_shape->at(Conv2DParam::kWeight).ndim() == 0) {
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
}
Expand Down
3 changes: 1 addition & 2 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,11 @@ bool Conv2DRel(const Array<Type>& types,
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
std::vector<IndexExpr> wshape(
{param->channels / param->groups,
{param->channels,
dshape_nchw[1] / param->groups,
param->kernel_size[0],
param->kernel_size[1]});
wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
wshape[kernel_layout.Indexof('O')] *= param->groups;
channels = param->channels;
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
Expand Down
135 changes: 83 additions & 52 deletions topi/python/topi/arm_cpu/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
from ..util import traverse_inline, get_const_tuple, const_matrix
from ..nn import dilate, pad, conv2d, conv2d_alter_layout, conv2d_winograd_without_weight_transform
from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
conv2d_winograd_without_weight_transform, depthwise_conv2d_nchw
from ..nn.util import get_const_int, get_pad_tuple

@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
Expand Down Expand Up @@ -556,7 +557,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
if out_dtype == "" or out_dtype == "same":
out_dtype = tinfos[0].dtype

if layout != 'NCHW' or groups != 1:
if layout != 'NCHW':
return None
if dilation != (1, 1):
warnings.warn("Does not support weight pre-transform for dilated convolution.")
Expand All @@ -566,54 +567,84 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
N, CI, H, W = get_const_tuple(data.shape)
CO, _, KH, KW = get_const_tuple(kernel.shape)

# query config of this workload
workload = autotvm.task.args_to_workload(
[data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
target = tvm.target.current_target()
dispatch_ctx = autotvm.DispatchContext.current
cfg = dispatch_ctx.query(target, workload)

if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None

if cfg.template_key == 'direct': # pack weight tensor
VC = cfg['tile_co'].size[-1]
new_attrs['kernel_layout'] = 'OIHW%do' % VC

# Store the same config for the altered operator (workload)
new_data = data
new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
dispatch_ctx.update(target, new_workload, cfg)

return F.nn.conv2d(*copy_inputs, **new_attrs)
else: # pre-compute weight transformation in winograd
if "-device=arm_cpu" in target.options:
tile_size = 4
VC = cfg['tile_k'].size[-1]
if groups == 1:
# query config of this workload
workload = autotvm.task.args_to_workload(
[data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
target = tvm.target.current_target()
dispatch_ctx = autotvm.DispatchContext.current
cfg = dispatch_ctx.query(target, workload)

if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None

if cfg.template_key == 'direct': # pack weight tensor
VC = cfg['tile_co'].size[-1]
new_attrs['kernel_layout'] = 'OIHW%do' % VC

# Store the same config for the altered operator (workload)
new_data = data
new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
dispatch_ctx.update(target, new_workload, cfg)

return F.nn.conv2d(*copy_inputs, **new_attrs)
else: # pre-compute weight transformation in winograd
if "-device=arm_cpu" in target.options:
tile_size = 4
VC = cfg['tile_k'].size[-1]
else:
from ..mali.conv2d import _pick_tile_size
tile_size = _pick_tile_size(tinfos[0], tinfos[1])
VC = cfg['tile_bna'].val

weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
tile_size=tile_size)
weight = F.reshape(weight,
newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])

copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size

# Store the same config for the altered operator (workload)
new_data = data
new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC),
kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation,
new_attrs[data_layout_key], out_dtype, tile_size],
conv2d_winograd_without_weight_transform)
dispatch_ctx.update(target, new_workload, cfg)

return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
else:
workload = autotvm.task.args_to_workload(
[data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
target = tvm.target.current_target()
dispatch_ctx = autotvm.DispatchContext.current
cfg = dispatch_ctx.query(target, workload)

if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload)
return None
if cfg.template_key == 'contrib_spatial_pack':
VC = cfg['tile_co'].size[-1]
new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])

# Store the same config for the altered operator (workload)
new_data = data
CO, M, KH, KW = get_const_tuple(kernel.shape)
new_kernel = tvm.placeholder((CO // VC, M, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, out_dtype],
depthwise_conv2d_nchw)
dispatch_ctx.update(target, new_workload, cfg)

return F.nn.conv2d(*copy_inputs, **new_attrs)
else:
from ..mali.conv2d import _pick_tile_size
tile_size = _pick_tile_size(tinfos[0], tinfos[1])
VC = cfg['tile_bna'].val

weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size)
weight = F.reshape(weight,
newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])

copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size

# Store the same config for the altered operator (workload)
new_data = data
new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC),
kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation,
new_attrs[data_layout_key], out_dtype, tile_size],
conv2d_winograd_without_weight_transform)
dispatch_ctx.update(target, new_workload, cfg)

return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
# currently we only have contrib_spatial_pack and direct template
# add more schedule templates.
return None
Loading