Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion topi/python/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,27 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
assert data.dtype == kernel.dtype, \
assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
"Do not support inputs with different data types now. ' \
'{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)

def _get_workload_int8(data, kernel, stride, padding, out_dtype):
    """Get the workload structure for a convolution with int8-style inputs.

    Parameters
    ----------
    data : tensor-like
        Input whose ``shape`` holds four entries exposing ``.value``; they are
        read as (batch, in-channels, height, width) with the batch ignored.
    kernel : tensor-like
        Filter whose ``shape`` holds four entries exposing ``.value``; they are
        read as (out-channels, in-channels, kernel-height, kernel-width) with
        the second entry ignored.
    stride : int or tuple/list of two ints
        A scalar is used for both height and width.
    padding
        Forwarded unchanged to ``get_pad_tuple`` together with ``kernel``.
    out_dtype : str
        Output data type stored in the workload.

    Returns
    -------
    Workload
        Record built from ``data.dtype``, ``out_dtype`` and the extracted
        sizes, paddings and strides.

    Raises
    ------
    AssertionError
        If the dtypes differ and the pair is not (uint8 data, int8 kernel).
    """
    _, CI, IH, IW = [x.value for x in data.shape]
    CO, _, KH, KW = [x.value for x in kernel.shape]
    # Only the first two values returned by get_pad_tuple are used here.
    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride
    # The message previously mixed quote styles across a line continuation,
    # which leaked stray quote characters and indentation into the text.
    assert (data.dtype == kernel.dtype) or \
        (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
        ("Do not support inputs with different data types now. "
         "{} vs. {}").format(data.dtype, kernel.dtype)
    return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)



@tvm.target.generic_func
def _get_alter_layout_schedule(wkl):
# pylint: disable=unreachable
Expand Down Expand Up @@ -118,6 +133,17 @@ def _get_schedule_NCHWc(wkl, layout, out_layout):
return wkl


@tvm.target.generic_func
def _get_schedule_NCHWc_int8(wkl, layout, out_layout):
    # pylint: disable=unreachable
    """Generic entry point for the platform specific int8 NCHWc schedule.

    This fallback body provides no schedule: it always raises a
    ``RuntimeError`` that names the current target.
    """
    current = tvm.target.current_target()
    message = "No schedule for current target:{}".format(current)
    raise RuntimeError(message)
    # Never reached; present only to suppress a pylint warning.
    return wkl


def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
"""Convolution operator in NCHW layout.

Expand Down
12 changes: 12 additions & 0 deletions topi/python/topi/x86/check_targets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument
"""Checks different x86 targets for target specific schedules"""

def check_skylake(target):
    """Check whether the target is skylake.

    Parameters
    ----------
    target
        Object exposing an iterable ``options`` attribute of option strings.

    Returns
    -------
    bool
        True when ``'-mcpu=skylake-avx512'`` is one of ``target.options``,
        False otherwise.
    """
    # A membership test replaces the hand-written search loop.
    return '-mcpu=skylake-avx512' in target.options
145 changes: 131 additions & 14 deletions topi/python/topi/x86/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from .. import nn
from ..nn.util import infer_pad, infer_stride
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \
_get_workload, _get_schedule, _get_schedule_NCHWc, \
_get_alter_layout_schedule, Workload
_get_workload, _get_workload_int8, _get_schedule, _get_schedule_NCHWc, \
_get_schedule_NCHWc_int8, _get_alter_layout_schedule, Workload

from . import conv2d_avx_1x1, conv2d_avx_common
from .conv2d_avx_common import AVXConvCommonFwd
from .conv2d_avx_1x1 import AVXConv1x1Fwd
from .check_targets import check_skylake

@_get_schedule.register("cpu")
def _get_schedule_conv(wkl):
Expand Down Expand Up @@ -100,10 +101,95 @@ def _get_schedule_conv(wkl):
sch = _SCHEDULES_AVX[idx]
return sch

def _get_schedule_conv_int8(wkl):
    """Select the AVX schedule used for an int8 convolution workload.

    Workloads listed in the table below get a hand-picked schedule; any other
    workload falls back to the default schedule of the 1x1 module (when both
    kernel dimensions are 1) or of the common module. The channel block size
    is 16 when the current target passes ``check_skylake`` and 8 otherwise.
    """
    target = tvm.target.current_target(allow_none=False)
    vec_len = 16 if check_skylake(target) else 8

    def u8(*dims):
        # uint8 input, int32 output workload
        return Workload('uint8', 'int32', *dims)

    def common(*tail):
        return AVXConvCommonFwd(vec_len, vec_len, *tail)

    def one_by_one(oh_factor, ow_factor):
        return AVXConv1x1Fwd(vec_len, vec_len, oh_factor, ow_factor)

    # Each known workload is kept next to the schedule chosen for it.
    table = [
        # Following are for INT8 kernels / operations
        # workloads of resnet18_v1 on imagenet
        (u8(56, 56, 64, 64, 3, 3, 1, 1, 1, 1), common(28, False)),
        (u8(56, 56, 64, 64, 1, 1, 0, 0, 1, 1), one_by_one(1, 28)),
        (u8(56, 56, 64, 128, 3, 3, 1, 1, 2, 2), common(28, False)),
        (u8(56, 56, 64, 128, 1, 1, 0, 0, 2, 2), one_by_one(1, 28)),
        (u8(28, 28, 128, 128, 3, 3, 1, 1, 1, 1), common(28, False)),
        (u8(28, 28, 128, 256, 3, 3, 1, 1, 2, 2), common(14, False)),
        (u8(28, 28, 128, 256, 1, 1, 0, 0, 2, 2), one_by_one(2, 14)),
        (u8(14, 14, 256, 256, 3, 3, 1, 1, 1, 1), common(14, True)),
        (u8(14, 14, 256, 512, 3, 3, 1, 1, 2, 2), common(7, True)),
        (u8(14, 14, 256, 512, 1, 1, 0, 0, 2, 2), one_by_one(1, 7)),
        (u8(7, 7, 512, 512, 3, 3, 1, 1, 1, 1), common(7, True)),
        # workloads of resnet34_v1 on imagenet, no extra workload required
        # workloads of resnet50_v1 on imagenet
        (u8(56, 56, 64, 256, 1, 1, 0, 0, 1, 1), one_by_one(2, 28)),
        (u8(56, 56, 256, 64, 1, 1, 0, 0, 1, 1), one_by_one(2, 28)),
        (u8(56, 56, 256, 128, 1, 1, 0, 0, 2, 2), one_by_one(2, 28)),
        (u8(28, 28, 128, 512, 1, 1, 0, 0, 1, 1), one_by_one(2, 28)),
        (u8(56, 56, 256, 512, 1, 1, 0, 0, 2, 2), one_by_one(2, 28)),
        (u8(28, 28, 512, 128, 1, 1, 0, 0, 1, 1), one_by_one(2, 28)),
        (u8(28, 28, 512, 256, 1, 1, 0, 0, 2, 2), one_by_one(2, 14)),
        (u8(14, 14, 256, 1024, 1, 1, 0, 0, 1, 1), one_by_one(2, 14)),
        (u8(28, 28, 512, 1024, 1, 1, 0, 0, 2, 2), one_by_one(2, 14)),
        (u8(14, 14, 1024, 256, 1, 1, 0, 0, 1, 1), one_by_one(2, 14)),
        (u8(14, 14, 1024, 512, 1, 1, 0, 0, 2, 2), one_by_one(2, 7)),
        (u8(7, 7, 512, 2048, 1, 1, 0, 0, 1, 1), one_by_one(2, 7)),
        (u8(14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2), one_by_one(2, 7)),
        (u8(7, 7, 2048, 512, 1, 1, 0, 0, 1, 1), one_by_one(2, 7)),
        # workloads of resnet101_v1 on imagenet, no extra workload required
        # workloads of resnet152_v1 on imagenet, no extra workload required
        # workloads of resnet18_v2 on imagenet, no extra workload required
        # workloads of resnet34_v2 on imagenet, no extra workload required
    ]

    for known, schedule in table:
        if known == wkl:
            return schedule

    # Unknown workload: use the default schedule of the matching module.
    if wkl.hkernel == 1 and wkl.wkernel == 1:
        fallback = conv2d_avx_1x1
    else:
        fallback = conv2d_avx_common
    return fallback._get_default_schedule(wkl, vec_len)

@_get_schedule_NCHWc.register("cpu")
def _get_schedule_NCHWc_x86(wkl, layout, out_layout):
    """Return the "cpu" NCHWc schedule for ``wkl``.

    ``layout`` and ``out_layout`` are accepted for the registered signature
    but are not consulted; the result comes from ``_get_schedule_conv``.
    """
    schedule = _get_schedule_conv(wkl)
    return schedule

@_get_schedule_NCHWc_int8.register("cpu")
def _get_schedule_NCHWc_x86_int8(wkl, layout, out_layout):
    """Return the "cpu" int8 NCHWc schedule for ``wkl``.

    ``layout`` and ``out_layout`` are accepted for the registered signature
    but are not consulted; the result comes from ``_get_schedule_conv_int8``.
    """
    schedule = _get_schedule_conv_int8(wkl)
    return schedule

@_get_alter_layout_schedule.register("cpu")
def _get_alter_layout_schedule_x86(wkl):
    """Return the "cpu" schedule used when altering the layout for ``wkl``.

    The result comes from ``_get_schedule_conv``.
    """
    schedule = _get_schedule_conv(wkl)
    return schedule
Expand Down Expand Up @@ -162,20 +248,37 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)



@conv2d_NCHWc.register("cpu")
def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride,
                            padding, layout, out_layout, out_dtype):
    """Declare the NCHWc convolution for "cpu", dispatching on the data dtype.

    Parameters
    ----------
    data : tensor
        Input whose shape has five entries, read as
        (n, ic_chunk, h, w, ic_block).
    kernel : tensor
        Filter tensor; only its ``dtype`` is read here before it is handed to
        the selected declaration function.
    num_filter : int
        Number of output channels of the un-blocked kernel.
    kernel_size : pair
        Unpacked as (kh, kw).
    stride, padding
        Forwarded to the workload builder.
    layout, out_layout
        Forwarded to the schedule lookup.
    out_dtype : str
        Output data type recorded in the workload.

    Returns
    -------
    The result of the declaration function matching the chosen schedule type.
    """
    n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
    ic = ic_chunk * ic_block
    kh, kw = kernel_size

    # Placeholders carry the real input dtypes, so the uint8/int8 pairing is
    # visible to the workload builder. An earlier computation built from
    # out_dtype placeholders was always overwritten and has been dropped.
    original_data = tvm.placeholder((n, ic, h, w), dtype=data.dtype)
    original_kernel = tvm.placeholder((num_filter, ic, kh, kw),
                                      dtype=kernel.dtype)

    # Use int8 declarations and schedules if the input data is of uint8 dtype
    if data.dtype == 'uint8':
        decl_funcs = {
            AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc_int8,
            AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc_int8
        }
        wkl = _get_workload_int8(original_data, original_kernel,
                                 stride, padding, out_dtype)
        sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout)
    else:
        decl_funcs = {
            AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc,
            AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc
        }
        wkl = _get_workload(original_data, original_kernel,
                            stride, padding, out_dtype)
        sch = _get_schedule_NCHWc(wkl, layout, out_layout)
    return decl_funcs[type(sch)](wkl, sch, data, kernel)


Expand Down Expand Up @@ -289,10 +392,6 @@ def traverse(op):
def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
layout, out_layout, outs):
"""Create schedule for tensors"""
_AVX_SCH_TO_SCH_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
}
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []

Expand All @@ -317,15 +416,33 @@ def traverse(op):
data_pad = data
data = data_pad.op.input_tensors[0]

_AVX_SCH_TO_SCH_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
}

# Use int8 schedules if the input data is of int8 dtype
if data.dtype == 'uint8':
_AVX_SCH_TO_SCH_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc_int8,
AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc_int8
}

n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
ic = ic_chunk * ic_block
original_data = tvm.placeholder((n, ic, h, w), dtype=conv_out.dtype)
original_data = tvm.placeholder((n, ic, h, w), dtype=data.dtype)

kh, kw = kernel_size
original_kernel = tvm.placeholder((num_filter, ic, kh, kw), dtype=conv_out.dtype)
original_kernel = tvm.placeholder((num_filter, ic, kh, kw),
dtype=kernel.dtype)

wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
sch = _get_schedule_NCHWc(wkl, layout, out_layout)
if data.dtype == 'uint8':
wkl = _get_workload_int8(original_data, original_kernel,
stride, padding, conv_out.dtype)
sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout)
else:
wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
sch = _get_schedule_NCHWc(wkl, layout, out_layout)
_AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec,
kernel, conv_out, outs[0])

Expand Down
117 changes: 117 additions & 0 deletions topi/python/topi/x86/conv2d_avx_1x1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
import topi

from ..util import get_const_tuple
from ..nn.conv2d import _get_schedule, _get_workload
from ..nn.util import infer_pad, infer_stride
from ..nn.pad import pad
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake

# Schedule parameters for the 1x1 AVX convolution: ic_bn / oc_bn are the
# input / output channel block sizes, and oh_factor / ow_factor are the split
# factors applied to the output height / width axes.
AVXConv1x1Fwd = namedtuple('AVXConv1x1Fwd', ['ic_bn', 'oc_bn', 'oh_factor', 'ow_factor'])

Expand Down Expand Up @@ -229,3 +232,117 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last):
s[O].parallel(parallel_axis)

return s


def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel):
    """Declare the int8 1x1 convolution compute in NCHWc layout.

    The input is padded when the workload asks for it, the kernel is reshaped
    so that its last three dimensions are merged, and the product of data and
    kernel (both cast to the workload's output dtype) is summed over the
    input-channel axes, which are split in groups of four values.
    """
    out_dtype = wkl.out_dtype
    pad_h, pad_w = wkl.hpad, wkl.wpad
    stride_h, stride_w = wkl.hstride, wkl.wstride

    batch_size = data.shape[0]
    out_height = (wkl.height + 2 * pad_h - wkl.hkernel) // stride_h + 1
    out_width = (wkl.width + 2 * pad_w - wkl.wkernel) // stride_w + 1

    needs_padding = (pad_h != 0 or pad_w != 0)
    data_pad = pad(data, (0, 0, pad_h, pad_w, 0), name="data_pad") if needs_padding else data

    oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)

    # Intel performs dot product of 2 "4" Int8 values
    n_elems = 4
    assert sch.ic_bn%n_elems == 0
    ic_outer = tvm.reduce_axis((0, wkl.in_filter//(sch.ic_bn)), name='ic_outer')
    ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
    ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')

    # Reshaping kernel as the last 2 dimensions are 1x1 (k_h x k_w)
    dims = kernel.shape
    merged_tail = dims[4] * dims[5] * dims[6]
    kernel = topi.reshape(kernel, (dims[0], dims[1], dims[2], dims[3], merged_tail))

    reduce_axes = [ic_outer, ic_f_inner, ic_s_inner]
    return tvm.compute(
        oshape,
        lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(
            data_pad[n, ic_outer, oh*stride_h, ow*stride_w,
                     ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
            kernel[oc_chunk, ic_outer, ic_f_inner,
                   oc_block, ic_s_inner].astype(out_dtype),
            axis=reduce_axes),
        name='conv2d_NCHWc_int8',
        tag="conv2d_NCHWc_int8")


def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
    """
    Defines the schedule for INT8 for intel machines
    Uses the Intel intrinsics to use INT8 operations
    More details - https://software.intel.com/en-us/articles/
    lower-numerical-precision-deep-learning-inference-and-training

    When the current target does not pass ``check_skylake`` the schedule is
    returned without any changes.
    """
    target = tvm.target.current_target(allow_none=False)
    if not check_skylake(target):
        return s
    # Skylake and future processors have 16 vector lanes
    int32_lanes = 16

    # schedule data
    data_vec = data
    if isinstance(s[data_vec].op, tvm.tensor.ComputeOp):
        _, ic_chunk, ih, _, _ = s[data_vec].op.axis
        fused_data_axis = s[data_vec].fuse(ic_chunk, ih)
        s[data_vec].parallel(fused_data_axis)

    conv, out = conv_out, last
    conv_cache = s.cache_write(conv, 'global')

    _, oc_chunk, oh, ow, oc_block = s[conv].op.axis
    oh_outer, oh_inner = s[conv].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[conv].split(ow, factor=sch.ow_factor)
    s[conv].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
    s[conv].vectorize(oc_block)

    parallel_axis = s[conv].fuse(oc_chunk, oh_outer)
    s[conv_cache].compute_at(s[conv], parallel_axis)
    if conv == out:
        s[conv].parallel(parallel_axis)

    # schedule the cached write stage, where the reduction lives
    _, oc_chunk, oh, ow, oc_block = s[conv_cache].op.axis
    ic_outer, ic_f_inner, ic_s_inner = s[conv_cache].op.reduce_axis

    assert sch.oc_bn % int32_lanes == 0

    oc_f_inner, oc_s_inner = s[conv_cache].split(oc_block, factor=int32_lanes)

    oh_outer, oh_inner = s[conv_cache].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[conv_cache].split(ow, factor=sch.ow_factor)

    s[conv_cache].reorder(oc_chunk, oh_outer, ow_outer, ic_outer, ic_f_inner, oh_inner,
                          ow_inner, oc_f_inner, oc_s_inner, ic_s_inner)
    s[conv_cache].fuse(oc_chunk, oh_outer)

    # replace the innermost loops with the int8 dot-product intrinsic
    intrin = dot_16x1x16_int8_int8_int32()
    s[conv_cache].tensorize(oc_s_inner, intrin)
    s[conv_cache].unroll(ow_inner)
    s[conv_cache].unroll(oh_inner)

    if conv != out:
        _, oc_chunk, oh, ow, oc_block = s[out].op.axis
        oh_outer, oh_inner = s[out].split(oh, factor=sch.oh_factor)
        ow_outer, ow_inner = s[out].split(ow, factor=sch.ow_factor)
        s[out].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

        parallel_axis = s[out].fuse(oc_chunk, oh_outer)
        s[conv].compute_at(s[out], parallel_axis)
        s[out].vectorize(oc_block)
        s[out].parallel(parallel_axis)

    return s
Loading