Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions topi/python/topi/arm_cpu/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"""Conv2D schedule for ARM CPU"""
from __future__ import absolute_import as _abs

import warnings

import numpy as np

import tvm
Expand Down Expand Up @@ -522,7 +524,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
out_dtype = attrs["out_dtype"]
out_dtype = tinfos[0].dtype if out_dtype == "same" else out_dtype

if layout != 'NCHW' or groups != 1 or dilation != (1, 1):
if layout != 'NCHW' or groups != 1:
return None
if dilation != (1, 1):
warnings.warn("Does not support weight pre-transform for dilated convolution.")
return None

data, kernel = tinfos[0:2]
Expand All @@ -531,7 +536,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):

# query config of this workload
workload = autotvm.task.args_to_workload(
[data, kernel, strides, padding, layout, out_dtype], conv2d)
[data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
target = tvm.target.current_target()
dispatch_ctx = autotvm.DispatchContext.current
cfg = dispatch_ctx.query(target, workload)
Expand All @@ -548,7 +553,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
new_data = data
new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, 'NCHW', out_dtype], conv2d)
[new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
dispatch_ctx.update(target, new_workload, cfg)

return sym.conv2d(*copy_inputs, **new_attrs)
Expand All @@ -574,7 +579,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC),
kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, new_attrs['layout'], out_dtype, tile_size],
[new_data, new_weight, strides, padding, dilation,
new_attrs['layout'], out_dtype, tile_size],
conv2d_winograd_without_weight_transform)
dispatch_ctx.update(target, new_workload, cfg)

Expand Down
5 changes: 5 additions & 0 deletions topi/python/topi/cuda/conv2d_winograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
new_attrs['out_layout'] = new_layout
new_attrs['kernel_layout'] = 'OIHW4o4i'
ic_block_factor = oc_block_factor = 4

# Store the same config for the altered operator (workload)
new_data = tvm.placeholder((N, CI // ic_block_factor, H, W, ic_block_factor),
dtype=data.dtype)
new_kernel = tvm.placeholder((CO // oc_block_factor, CI // ic_block_factor, KH, KW,\
Expand All @@ -387,7 +389,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
return sym.conv2d(*copy_inputs, **new_attrs)

if attrs.get_int_tuple("dilation") != (1, 1):
warnings.warn("Does not support weight pre-transform for dilated convolution.")
return None

# pre-compute weight transformation in winograd
tile_size = _infer_tile_size(tinfos[0], tinfos[1])

Expand All @@ -397,6 +401,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size

# Store the same config for the altered operator (workload)
new_data = data
new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO),
dtype=kernel.dtype)
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/mali/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def _schedule_winograd(cfg, s, op):

##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'mali', ['winograd'])
def conv2d_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
def conv2d_winograd_ww(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
"""TOPI compute callback"""
return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
tile_size)
Expand Down