8 changes: 4 additions & 4 deletions python/tvm/autotvm/tophub.py
@@ -36,13 +36,13 @@

# the version of each package
PACKAGE_VERSION = {
'arm_cpu': "v0.04",
'arm_cpu': "v0.05",
'llvm': "v0.03",

'cuda': "v0.04",
'cuda': "v0.05",
'rocm': "v0.02",
'opencl': "v0.02",
'mali': "v0.05",
'opencl': "v0.03",
'mali': "v0.06",

'vta': "v0.06",
}
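The version bumps above retire the previously published TopHub packages: this PR adds a tunable tile_size knob to the winograd templates, which changes the tuning-space layout, so logs tuned against the old spaces no longer match. A sketch of how the version likely feeds the package file name (the "<backend>_<version>.log" convention is assumed here, not shown in this diff):

# Hypothetical illustration of the lookup; see the rest of tophub.py for
# the real download/lookup helpers.
backend = 'arm_cpu'
package_name = "%s_%s.log" % (backend, PACKAGE_VERSION[backend])  # "arm_cpu_v0.05.log"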
61 changes: 38 additions & 23 deletions topi/python/topi/arm_cpu/conv2d.py
@@ -32,7 +32,7 @@
conv2d_winograd_nnpack_without_weight_transform, \
depthwise_conv2d_nchw
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices
from ..nn.winograd_util import winograd_transform_matrices, enum_tile_sizes

@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
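enum_tile_sizes, imported above, is new in this PR and its body is not part of this diff. A minimal sketch of the assumed behavior (the return value and any shape-based filtering are guesses to be checked against nn/winograd_util.py):

def enum_tile_sizes(data):
    """Candidate winograd output tile sizes m in F(m, r) for the tuner.

    Sketch only: m = 4 saves more arithmetic per output tile, while
    m = 2 is the numerically conservative fallback.
    """
    return [2, 4]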
@@ -302,11 +302,14 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
""" TOPI compute callback. Use winograd template """
tile_size = 4
return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout,
out_dtype, tile_size)
return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)

def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
tile_size=None):

cfg.define_knob('tile_size', enum_tile_sizes(data))
tile_size = tile_size if tile_size else cfg["tile_size"].val

def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
N, CI, IH, IW = get_const_tuple(data.shape)

if isinstance(dilation, int):
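The define_knob plus tile_size override above is the pattern this PR repeats across arm_cpu, cuda, and mali: tile_size joins the tuning space, but an explicit argument wins once alter_op_layout has already pre-transformed the weights for a fixed size. A commented restatement:

# Sketch of the override rule used by all three winograd templates here.
# - tile_size arg is None during tuning -> use the knob's current value.
# - tile_size arg is set (weights were pre-transformed for that size)
#   -> the argument overrides the knob so compute and kernel stay consistent.
tile_size = tile_size if tile_size else cfg["tile_size"].val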
@@ -326,11 +329,14 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
CO *= VC
KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))

assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
assert KH == KW and HSTR == 1 and WSTR == 1
data_pad = pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right])

r = KW
m = tile_size
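The move from a symmetric (HPAD, WPAD) pair to a full (top, left, bottom, right) tuple, computed on the dilated kernel extent, is what lets these templates accept asymmetric padding. A worked example under the usual get_pad_tuple semantics:

# A 3x3 kernel with dilation 2 covers an effective extent of 5:
KH, dilation_h = 3, 2
dilated_kernel_h = (KH - 1) * dilation_h + 1   # 5
# get_pad_tuple(padding, (5, 5)) then splits the required total padding;
# an even kernel or odd total makes the split asymmetric (one extra row
# on a single side), which the old two-value HPAD/WPAD form could not express.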
@@ -340,8 +346,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
K = CO
C = CI

H = (IH + 2 * HPAD - 3) // HSTR + 1
W = (IW + 2 * WPAD - 3) // WSTR + 1
H = (IH + pad_top + pad_bottom - KH) // HSTR + 1
W = (IW + pad_left + pad_right - KW) // WSTR + 1
nH, nW = (H + m-1) // m, (W + m-1) // m
P = N * nH * nW
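A quick check of the tile arithmetic above for a typical layer:

# IH = IW = 56, KH = KW = 3, pad 1 on every side, stride 1, m = tile_size = 4:
H = (56 + 1 + 1 - 3) // 1 + 1   # 56 -- spatial size preserved
nH = (56 + 4 - 1) // 4          # 14 tiles per axis
P = 1 * 14 * 14                 # 196 input tiles for batch N = 1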

@@ -510,12 +516,15 @@ def conv2d_arm_cpu_winograd_nnpack(
assert len(kernel.shape) == 4
CO, _, KH, KW = get_const_tuple(kernel.shape)
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))

assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
H = (IH + 2 * HPAD - 3) // HSTR + 1
W = (IW + 2 * WPAD - 3) // WSTR + 1
assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
H = (IH + pad_top + pad_bottom - KH) // HSTR + 1
W = (IW + pad_left + pad_right - KW) // WSTR + 1

cfg.define_knob('winograd_nnpack_algorithm', [convolution_algorithm])

@@ -530,7 +539,7 @@
output = tvm.contrib.nnpack.convolution_inference_without_weight_transform(
data, transformed_kernel,
bias=None,
padding=[HPAD, HPAD, WPAD, WPAD],
padding=[pad_top, pad_bottom, pad_left, pad_right],
stride=[HSTR, WSTR],
algorithm=cfg['winograd_nnpack_algorithm'].val)
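One subtlety in the call above: get_pad_tuple returns (top, left, bottom, right), while the nnpack wrapper's padding list is ordered [top, bottom, left, right], so the values are deliberately reordered rather than passed through. The mapping, spelled out:

# get_pad_tuple order:   (pad_top, pad_left, pad_bottom, pad_right)
# nnpack padding order:  [pad_top, pad_bottom, pad_left, pad_right]
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (3, 3))
nnpack_padding = [pad_top, pad_bottom, pad_left, pad_right]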

@@ -590,21 +599,24 @@ def conv2d_winograd_nnpack_ww(cfg, data, transformed_kernel, bias, strides,
assert len(transformed_kernel.shape) == 4
CO, _, _, _ = get_const_tuple(transformed_kernel.shape)
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
HPAD, WPAD, _, _ = get_pad_tuple(padding, (3, 3))
KH, KW = 3, 3
dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))

assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
H = (IH + 2 * HPAD - 3) // HSTR + 1
W = (IW + 2 * WPAD - 3) // WSTR + 1
assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
H = (IH + pad_top + pad_bottom - KH) // HSTR + 1
W = (IW + pad_left + pad_right - KW) // WSTR + 1

assert N == 1
with tvm.tag_scope("winograd_nnpack_conv2d_output"):
output = tvm.contrib.nnpack.convolution_inference_without_weight_transform(
data=data,
transformed_kernel=transformed_kernel,
bias=bias,
padding=[HPAD, HPAD, WPAD, WPAD],
padding=[pad_top, pad_bottom, pad_left, pad_right],
stride=[HSTR, WSTR],
algorithm=cfg['winograd_nnpack_algorithm'].val)

@@ -701,14 +713,16 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
dispatch_ctx.update(target, new_workload, cfg)

return F.nn.conv2d(*copy_inputs, **new_attrs)

elif cfg.template_key == "winograd": # pre-compute weight transformation in winograd
# safe default tile_size is 2 or 4:
# use 4 when the winograd polynomial is not excessive (KH < 4), else 2
if "-device=arm_cpu" in target.options:
tile_size = 4
VC = cfg['tile_k'].size[-1]
tile_size = 4 if KH < 4 else 2
else:
from ..mali.conv2d import _pick_tile_size
tile_size = _pick_tile_size(tinfos[0], tinfos[1])
VC = cfg['tile_bna'].val
tile_size = 4 if (H % 4 == 0) and (KH < 4) else 2

weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
tile_size=tile_size)
@@ -730,6 +744,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
dispatch_ctx.update(target, new_workload, cfg)

return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)

elif cfg.template_key in ["winograd_nnpack_fp16", "winograd_nnpack_fp32"]:
# pre-compute winograd_nnpack transform
# for winograd_nnpack_fp16, the precompute_prune pass must run on device,
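Restating the fallback heuristic added to _alter_conv2d_layout_arm above: pick the larger F(4, 3)-style transform only while the kernel is small enough that the winograd interpolation polynomial stays well conditioned (KH < 4); the mali branch additionally requires the output height to tile evenly. A standalone sketch (the function name is illustrative, not from the PR):

def pick_safe_tile_size(H, KH, require_even_tiling=False):
    # 4 when the winograd polynomial is not excessive -- and, for mali,
    # when H % 4 == 0 so the output splits into whole tiles; else 2.
    if KH < 4 and (not require_even_tiling or H % 4 == 0):
        return 4
    return 2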
76 changes: 38 additions & 38 deletions topi/python/topi/cuda/conv2d_winograd.py
@@ -20,28 +20,25 @@
import tvm
from tvm import autotvm

from .. import nn
from ..nn import conv2d, group_conv2d_nchw, conv2d_winograd_without_weight_transform
from ..nn import pad, conv2d, conv2d_alter_layout, group_conv2d_nchw, \
conv2d_winograd_without_weight_transform
from ..util import get_const_int, get_const_tuple, traverse_inline
from ..nn.util import get_pad_tuple
from ..generic import schedule_conv2d_winograd_without_weight_transform
from ..nn.winograd_util import winograd_transform_matrices
from ..nn.winograd_util import winograd_transform_matrices, enum_tile_sizes


def _infer_tile_size(data, kernel):
N, CI, H, W = get_const_tuple(data.shape)

if H % 8 == 0:
return 4
return 2

def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, pre_computed):
def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size=None,
pre_computed=False):
"""Compute declaration for winograd"""
assert layout == 'NCHW'

tile_size = _infer_tile_size(data, kernel)
cfg.define_knob('tile_size', enum_tile_sizes(data))
tile_size = tile_size if tile_size else cfg["tile_size"].val

N, CI, H, W = get_const_tuple(data.shape)


if not pre_computed: # kernel tensor is raw tensor, do strict check
if isinstance(dilation, int):
dilation_h = dilation_w = dilation
@@ -51,18 +48,18 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))

CO, CI, KH, KW = get_const_tuple(kernel.shape)
HPAD, WPAD, _, _ = nn.get_pad_tuple(padding, kernel)
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
assert HSTR == 1 and WSTR == 1 and KH == KW
else: # kernel tensor is pre-transfomred. this op is created by
# alter op layout, do not check
else: # kernel tensor is pre-transformed.
# this op is created by alter op layout, do not check
# dilation is not supported
HSTR = WSTR = 1
HPAD = WPAD = 1
KH = KW = 3
_, _, CI, CO = get_const_tuple(kernel.shape)

data_pad = nn.pad(data, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad")
data_pad = pad(data, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad")

r = KW
m = tile_size
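In the pre-transformed branch above, the kernel arrives with shape (alpha, alpha, CI, CO), so the original spatial extent is pinned rather than inferred from the tensor; the relation between the sizes (alpha = m + r - 1, visible in the mali hunk below) works out as:

# alpha = m + r - 1, where m is the output tile size and r the kernel extent.
alpha, m = 6, 4
r = alpha - m + 1   # 3 -- consistent with the hard-coded KH = KW = 3 above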
@@ -106,7 +103,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty

# inverse transform
r_a = tvm.reduce_axis((0, alpha), 'r_a')
r_b = tvm.reduce_axis((0, alpha), 'r_a')
r_b = tvm.reduce_axis((0, alpha), 'r_b')
inverse = tvm.compute((CO, P, m, m), lambda co, p, vh, vw:
tvm.sum(bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw],
axis=[r_a, r_b]), name='inverse')
@@ -280,7 +277,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
['cuda', 'gpu'], ['winograd'])
def conv2d_winograd_ww(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
pre_computed=True)
tile_size, pre_computed=True)


@autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
@@ -298,7 +295,7 @@ def _callback(op):


##### REGISTER ALTER OP LAYOUT #####
@nn.conv2d_alter_layout.register(["cuda", "gpu"])
@conv2d_alter_layout.register(["cuda", "gpu"])
def _alter_conv2d_layout(attrs, inputs, tinfos, F):
"""Alter op layout for pre-computing kernel transformation

@@ -382,25 +379,28 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
warnings.warn("Does not support weight pre-transform for dilated convolution.")
return None

# pre-compute weight transformation in winograd
tile_size = _infer_tile_size(tinfos[0], tinfos[1])

weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
tile_size=tile_size)
weight = F.transpose(weight, axes=[0, 1, 3, 2])
copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size

# Store the same config for the altered operator (workload)
new_data = data
new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation, layout, out_dtype, tile_size],
conv2d_winograd_without_weight_transform
)
dispatch_ctx.update(target, new_workload, cfg)
return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
if cfg.template_key == "winograd":
# safe default tile_size:
# 4 when H % 8 == 0 and the winograd polynomial is not excessive (KH < 4), else 2
tile_size = 4 if (H % 8 == 0) and (KH < 4) else 2
new_attrs['tile_size'] = tile_size

weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
tile_size=tile_size)
weight = F.transpose(weight, axes=[0, 1, 3, 2])
copy_inputs[1] = weight

# Store the same config for the altered operator (workload)
new_data = data
new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation, layout, out_dtype, tile_size],
conv2d_winograd_without_weight_transform
)
dispatch_ctx.update(target, new_workload, cfg)
return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)

if groups != CI:
workload = autotvm.task.args_to_workload(
[tinfos[0], tinfos[1], strides, padding, dilation, groups, out_dtype],
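Two things worth flagging in the CUDA alter-layout change above: the weight pre-transform now runs only when cfg.template_key == "winograd", so non-winograd templates fall through to the remaining cases (such as the group-conv path right after the guard), and the CUDA fallback demands stricter divisibility than ARM (H % 8 == 0 rather than H % 4 == 0). Concretely:

# H = 56, KH = 3:  56 % 8 == 0 and 3 < 4  -> tile_size = 4
# H = 28, KH = 3:  28 % 8 == 4            -> tile_size = 2
tile_size = 4 if (56 % 8 == 0) and (3 < 4) else 2   # -> 4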
37 changes: 19 additions & 18 deletions topi/python/topi/mali/conv2d.py
@@ -24,7 +24,7 @@
from ..util import traverse_inline, get_const_int, get_const_tuple
from ..nn import conv2d, conv2d_winograd_without_weight_transform, \
get_pad_tuple, pad, conv2d_alter_layout
from ..nn.winograd_util import winograd_transform_matrices
from ..nn.winograd_util import winograd_transform_matrices, enum_tile_sizes

# reuse some compute declarations from ARM CPU
from ..arm_cpu.conv2d import _decl_spatial_pack, _alter_conv2d_layout_arm
@@ -188,21 +186,16 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
return s

##### WINOGRAD TEMPLATE #####
def _pick_tile_size(data, kernel):
N, CI, H, W = get_const_tuple(data.shape)

if H % 4 == 0:
return 4
else:
return 2

@autotvm.register_topi_compute(conv2d, 'mali', ['winograd'])
def conv2d_mali_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
tile_size = _pick_tile_size(data, kernel)
return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
tile_size)
return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)

def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
tile_size=None):

cfg.define_knob('tile_size', enum_tile_sizes(data))
tile_size = tile_size if tile_size else cfg["tile_size"].val

def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
N, CI, IH, IW = get_const_tuple(data.shape)
if isinstance(dilation, int):
dilation_h = dilation_w = dilation
@@ -222,19 +217,22 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
CO *= VC
KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))

assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
assert KH == KW and HSTR == 1 and WSTR == 1
data_pad = pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right])

r = KW
m = tile_size
alpha = m + r - 1
A, B, G = winograd_transform_matrices(m, r, out_dtype)

H = (IH + 2 * HPAD - 3) // HSTR + 1
W = (IW + 2 * WPAD - 3) // WSTR + 1
H = (IH + pad_top + pad_bottom - KH) // HSTR + 1
W = (IW + pad_left + pad_right - KW) // WSTR + 1
nH, nW = (H + m-1) // m, (W + m-1) // m
P = N * nH * nW

@@ -250,6 +248,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
##### space definition end #####

if cfg.is_fallback:
# safe default tile_size is always 2 or 4:
# 4 when IH % 4 == 0 and the winograd polynomial is not excessive (KH < 4), else 2
cfg['tile_size'].val = 4 if (IH % 4 == 0) and (KH < 4) else 2
cfg['tile_bnb'].val = 4
cfg['tile_bna'].val = 4
while CO % cfg['tile_bna'].val != 0:
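The diff cuts off the body of the while loop above; presumably it shrinks tile_bna until it divides the output channel count, mirroring the usual fallback repair. A hedged sketch (verify against the full mali/conv2d.py):

# Assumed shape of the truncated loop body:
while CO % cfg['tile_bna'].val != 0:
    cfg['tile_bna'].val //= 2   # e.g. CO = 18: 4 -> 2, and 18 % 2 == 0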