diff --git a/python/tvm/autotvm/__init__.py b/python/tvm/autotvm/__init__.py index 7170dbdd8565..08cfbb2a95da 100644 --- a/python/tvm/autotvm/__init__.py +++ b/python/tvm/autotvm/__init__.py @@ -27,5 +27,6 @@ from .tuner import callback from .task import template, get_config, create, ConfigSpace, ConfigEntity, \ register_topi_compute, register_topi_schedule, \ - DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best + DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \ + ApplyGraphBest as apply_graph_best from .env import GLOBAL_SCOPE diff --git a/python/tvm/autotvm/task/__init__.py b/python/tvm/autotvm/task/__init__.py index 8efb0e61b518..04bcec92fd57 100644 --- a/python/tvm/autotvm/task/__init__.py +++ b/python/tvm/autotvm/task/__init__.py @@ -10,7 +10,7 @@ from .space import ConfigSpace, ConfigEntity from .code_hash import attach_code_hash, attach_code_hash_to_arg from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \ - FallbackContext, clear_fallback_cache + FallbackContext, clear_fallback_cache, ApplyGraphBest from .topi_integration import register_topi_compute, register_topi_schedule from .nnvm_integration import extract_from_graph, extract_from_multiple_graph diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index 8e159cc412c9..164877e3b451 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -345,3 +345,83 @@ def clear_fallback_cache(target, workload): while not isinstance(context, FallbackContext): context = context._old_ctx context.clear_cache(target, workload) + +class ApplyGraphBest(DispatchContext): + """Load the optimal schedules obtained from graph-level tuning. + + The input records should be in ascending order of + node index for the target operator. Usually they can be obtained + with the graph tuner. + + This context maintains an internal counter to indicate the current + node index. + """ + def __init__(self, records): + """ + Parameters + ---------- + records : str or iterator of (MeasureInput, MeasureResult) + Collection of tuning records. + If it is a str, it should be the filename of a records log file. + Each row of this file is an encoded record pair. + Otherwise, it should be an iterator. + """ + from ..record import load_from_file + + super(ApplyGraphBest, self).__init__() + if isinstance(records, str): + records = load_from_file(records) + self._records = list(records) + self._counter = 0 + self._global_cfg_dict = {} + + def _query_inside(self, target, workload): + """ + Query the context to get config from records. + + Parameters + ---------- + target : Target + The current target. + workload : Workload + The current workload. + + Returns + ------- + cfg : ConfigSpace + The specific configuration. + """ + cfg = self._records[self._counter][0].config + self._counter += 1 + return cfg + + def query_global_dict(self, key): + """ + Query the context to get config from the global + config dictionary. + + Parameters + ---------- + key : str + Key to query the config. + + Returns + ------- + cfg : ConfigSpace + The specific configuration. + """ + return self._global_cfg_dict[key] + + def update_global_dict(self, key, val): + """ + Update the global config dictionary. + + Parameters + ---------- + key : str + Key of the config. + + val : ConfigSpace + Value of the config.
+ """ + self._global_cfg_dict[key] = val diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 6fe59a909510..f766d827686d 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -1,12 +1,17 @@ -# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member """Conv2D schedule on x86""" import tvm +from tvm import autotvm +from tvm.autotvm.task.dispatcher import ApplyGraphBest +from tvm.autotvm.task.nnvm_integration import deserialize_args +from tvm.autotvm.task import register, get_config from .. import generic, tag from .. import nn -from ..nn.util import infer_pad, infer_stride +from ..util import get_const_tuple from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \ - _get_workload, _get_workload_int8, _get_schedule, _get_schedule_NCHWc, \ + _get_workload_int8, _get_schedule, _get_schedule_NCHWc, \ _get_schedule_NCHWc_int8, _get_alter_layout_schedule, Workload +from ..nn.pad import pad from . import conv2d_avx_1x1, conv2d_avx_common from .conv2d_avx_common import AVXConvCommonFwd @@ -194,103 +199,164 @@ def _get_schedule_NCHWc_x86_int8(wkl, layout, out_layout): def _get_alter_layout_schedule_x86(wkl): return _get_schedule_conv(wkl) -@conv2d.register("cpu") -def _declaration_conv(data, kernel, stride, padding, layout, out_dtype): - _AVX_SCH_TO_DECL_FUNC = { - AVXConvCommonFwd: conv2d_avx_common._declaration_conv, - AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv - } - out_dtype = data.dtype if out_dtype is None else out_dtype - target = tvm.target.current_target(allow_none=False) - wkl = _get_workload(data, kernel, stride, padding, out_dtype) - if layout == 'NCHW': - sch = _get_schedule(wkl) - return _AVX_SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, layout, out_dtype) - elif layout == 'HWCN': - return nn.conv2d_hwcn(data, kernel, stride, padding, out_dtype) - elif layout == 'NHWC': - return nn.conv2d_nhwc(data, kernel, stride, padding, out_dtype) + +def _get_fp32_len(): + fp32_vec_len = 8 + target = tvm.target.current_target() + if target is not None: + for opt in target.options: + if opt == '-mcpu=skylake-avx512': + fp32_vec_len = 16 + return fp32_vec_len + + +def _get_default_sch(workload): + fp32_vec_len = _get_fp32_len() + _, _, kh, kw, _ = workload[2] + is_kernel_1x1 = kh == 1 and kw == 1 + if is_kernel_1x1: + cfg = conv2d_avx_1x1._fallback_schedule(workload, fp32_vec_len) else: - raise ValueError("not support this layout {} yet".format(layout)) + cfg = conv2d_avx_common._fallback_schedule(workload, fp32_vec_len) + return cfg -@conv2d_alter_layout.register("cpu") -def _alter_conv2d_layout(attrs, inputs, tinfos): - import nnvm.symbol as sym - copy_inputs = [s for s in inputs] - new_attrs = {k : attrs[k] for k in attrs.keys()} - # only optimize for NCHW, groups=1 conv - if attrs['layout'] != 'NCHW' or attrs.get_int("groups") != 1: - return None +def _create_schedule_template(cfg, data, kernel, strides, padding, layout): + """Create schedule configuration from input arguments""" + dshape = get_const_tuple(data.shape) + kshape = get_const_tuple(kernel.shape) + if layout == 'NCHW': + n, ic, h, w = dshape + oc, _, kh, kw = kshape + else: + raise ValueError("Not support this layout {} with " + "schedule template.".format(layout)) + is_kernel_1x1 = kh == 1 and kw == 1 + ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding) + sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, 
strides) + oh = (h - kh + 2 * ph) // sh + 1 + ow = (w - kw + 2 * pw) // sw + 1 + + # Create schedule config + cfg.define_split("tile_ic", ic, num_outputs=2) + cfg.define_split("tile_oc", oc, num_outputs=2) + cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64) + if is_kernel_1x1: + cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1]) + else: + cfg.define_knob("unroll_kw", [True, False]) - data = tinfos[0] - kernel = tinfos[1] - import ast - padding = ast.literal_eval(attrs['padding']) - stride = ast.literal_eval(attrs['strides']) +def conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype): + """convert argument to workload""" + if len(kernel.shape) == 4: + raw_kernel = kernel + else: # the input kernel is transformed by alter_op_layout + shape = get_const_tuple(kernel.shape) + raw_kernel = tvm.placeholder((shape[0] * shape[4], shape[1], shape[2], shape[3]), + dtype=kernel.dtype) + return ('conv2d', ) + autotvm.task.args_to_workload( + [data, raw_kernel, strides, padding, layout, out_dtype]) - wkl = _get_workload(data, kernel, stride, padding, data.dtype) - sch = _get_alter_layout_schedule(wkl) - is_kernel_1x1 = isinstance(sch, AVXConv1x1Fwd) - ic_bn, oc_bn = sch.ic_bn, sch.oc_bn - new_attrs['layout'] = 'NCHW%dc' % ic_bn - new_attrs['out_layout'] = 'NCHW%dc' % oc_bn +@conv2d.register("cpu") +@autotvm.task.dispatcher +def conv2d_x86(data, kernel, strides, padding, layout, out_dtype): + """x86 conv2d declaration.""" + return conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype) - if is_kernel_1x1: - # (oc, ic, h, w) -> (OC, IC, ic, oc, h, w) - new_attrs['kernel_layout'] = 'OI%di%doHW' % (ic_bn, oc_bn) + +@conv2d_x86.register(["direct"]) +def _declaration_conv(cfg, data, kernel, strides, padding, layout, out_dtype): + out_dtype = data.dtype if out_dtype is None else out_dtype + padding = padding if isinstance(padding, (tuple, list)) else (padding, padding) + strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) + if layout == 'NCHW': + _create_schedule_template(cfg, data, kernel, strides, padding, layout) + if cfg.is_fallback: + workload = conv_arg_to_workload(data, kernel, strides, padding, + layout, out_dtype) + cfg = _get_default_sch(workload) + args = [cfg, data, kernel, strides, padding, layout, out_dtype] + return _declaration_conv_impl(*args) + elif layout == 'HWCN': + return nn.conv2d_hwcn(data, kernel, strides, padding, out_dtype) + elif layout == 'NHWC': + return nn.conv2d_nhwc(data, kernel, strides, padding, out_dtype) else: - # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) - new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) + raise ValueError("not support this layout {} yet".format(layout)) - return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) +def _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtype): + out_dtype = data.dtype if out_dtype is None else out_dtype + assert layout == 'NCHW', "only support NCHW convolution for AVX" + HPAD, WPAD = padding + HSTR, WSTR = strides -@conv2d_NCHWc.register("cpu") -def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride, - padding, layout, out_layout, out_dtype): - _AVX_SCH_TO_DECL_FUNC = { - AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc, - AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc - } + batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape) + num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape) - # Use int8 schedules if the input 
data is of int8 dtype - if data.dtype == 'uint8': - _AVX_SCH_TO_DECL_FUNC = { - AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc_int8, - AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc_int8 - } + pad_height = in_height + 2 * HPAD + pad_width = in_width + 2 * WPAD - n, ic_chunk, h, w, ic_block = [x.value for x in data.shape] - ic = ic_chunk * ic_block - kh, kw = kernel_size - if data.dtype == 'uint8': - wkl = _get_workload_int8(tvm.placeholder((n, ic, h, w), dtype=data.dtype), - tvm.placeholder((num_filter, ic, kh, kw), - dtype=kernel.dtype), - stride, padding, out_dtype) - sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout) - else: - wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=data.dtype), - tvm.placeholder((num_filter, ic, kh, kw), - dtype=kernel.dtype), - stride, padding, out_dtype) - sch = _get_schedule_NCHWc(wkl, layout, out_layout) - return _AVX_SCH_TO_DECL_FUNC[type(sch)](wkl, sch, data, kernel) + out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1 + out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1 - -@generic.schedule_conv2d_nchw.register(["cpu"]) -def schedule_conv2d(outs): + # pack data + DOPAD = (HPAD != 0 or WPAD != 0) + if DOPAD: + data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") + else: + data_pad = data + + # fetch schedule + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + + shape = (batch_size, in_channel // ic_bn, pad_height, ic_bn, pad_width) + data_vec = tvm.compute(shape, + lambda n, C, h, c, w: data_pad[n, C * ic_bn + c, h, w], + name='data_vec') + + # pack kernel + shape = (num_filter//oc_bn, in_channel//ic_bn, + kernel_height, kernel_width, ic_bn, oc_bn) + kernel_vec = tvm.compute(shape, + lambda CO, CI, h, w, ci, co: + kernel[CO * oc_bn + co, CI * ic_bn + ci, h, w], + name='kernel_vec') + + # convolution + oshape = (batch_size, num_filter//oc_bn, out_height, out_width, oc_bn) + unpack_shape = (batch_size, num_filter, out_height, out_width) + + ic = tvm.reduce_axis((0, in_channel), name='ic') + kh = tvm.reduce_axis((0, kernel_height), name='kh') + kw = tvm.reduce_axis((0, kernel_width), name='kw') + + conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: + tvm.sum(data_vec[n, ic//ic_bn, oh*HSTR+kh, ic%ic_bn, + ow*WSTR+kw].astype(out_dtype) * + kernel_vec[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, + oc_block].astype(out_dtype), + axis=[ic, kh, kw]), name='conv') + + unpack = tvm.compute(unpack_shape, + lambda n, c, h, w: conv[n, c // oc_bn, h, w, c % oc_bn] + .astype(out_dtype), + name='output_unpack', + tag='conv2d_nchw', + attrs={'workload': + conv_arg_to_workload(data, kernel, strides, + padding, layout, + out_dtype)}) + return unpack + + +@autotvm.task.register_topi_schedule(generic.schedule_conv2d_nchw, 'cpu', ['direct']) +def schedule_conv2d(cfg, outs): """Create schedule for tensors""" - _AVX_SCH_TO_SCH_FUNC = { - AVXConvCommonFwd: conv2d_avx_common._schedule_conv, - AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv - } s = tvm.create_schedule([x.op for x in outs]) - target = tvm.target.current_target(allow_none=False) scheduled_ops = [] def traverse(op): @@ -316,16 +382,25 @@ def traverse(op): if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: data_pad = data data = data_pad.op.input_tensors[0] - padding = infer_pad(data, data_pad) - if data_pad is None: - stride = infer_stride(data, kernel, output) - else: - stride = infer_stride(data_pad, kernel, output) - wkl = _get_workload(data, kernel, stride, padding, output.dtype) - sch = _get_schedule(wkl) - 
_AVX_SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec, - kernel, kernel_vec, conv_out, output, outs[0]) + _, _, kh, kw = get_const_tuple(kernel.shape) + is_kernel_1x1 = kh == 1 and kw == 1 + current_cfg = cfg + if cfg.is_fallback: + workload_attr = op.attrs["workload"] + strides = (int(workload_attr[3][0].value), int(workload_attr[3][1].value)) + padding = (int(workload_attr[4][0].value), int(workload_attr[4][1].value)) + layout = workload_attr[5].value + out_dtype = workload_attr[6].value + workload = conv_arg_to_workload(data, kernel, strides, padding, + layout, out_dtype) + current_cfg = _get_default_sch(workload) + args = [s, current_cfg, data, data_pad, data_vec, kernel_vec, conv_out, + output, outs[0]] + if is_kernel_1x1: + conv2d_avx_1x1._schedule_conv(*args) + else: + conv2d_avx_common._schedule_conv(*args) scheduled_ops.append(op) @@ -333,7 +408,7 @@ def traverse(op): return s -@generic.schedule_conv2d_nhwc.register(["cpu"]) +@generic.schedule_conv2d_nhwc.register("cpu") def schedule_conv2d_nhwc(outs): """Create schedule for tensors""" s = tvm.create_schedule([x.op for x in outs]) @@ -388,12 +463,223 @@ def traverse(op): return s -@generic.schedule_conv2d_NCHWc.register(["cpu"]) -def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, +# Define template function for autotvm task +# We define schedule template in this function instead of +# declaration function since actual input arguments need +# to be altered by the schedule selected. +@register("topi_x86_conv2d_NCHWc") +def _topi_nn_conv2d_NCHWc(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + data, kernel = args[:2] + strides = args[4] + padding = args[5] + layout = args[6] + raw_data_shape = get_const_tuple(data.shape) + raw_kernel_shape = get_const_tuple(kernel.shape) + + # get config here + cfg = get_config() + _create_schedule_template(cfg, data, kernel, strides, padding, layout) + + # change shape with the value in config + ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], + cfg["tile_ow"].size[-1]) + new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn, + raw_data_shape[2], raw_data_shape[3], ic_bn) + data_layout = "NCHW%dc" % ic_bn + out_layout = "NCHW%dc" % oc_bn + new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn, + raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn) + args[0] = tvm.placeholder(new_data_shape, data.dtype) + args[1] = tvm.placeholder(new_kernel_shape, kernel.dtype) + args[6] = data_layout + args[7] = out_layout + + C = _declaration_conv_NCHWc(cfg, *args, **kwargs) + s = _schedule_conv2d_NCHWc(cfg, args[2], args[3], args[4], args[5], + args[6], args[7], [C]) + return s, [args[0], args[1], C] + + +def conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides, + padding, layout, out_layout, out_dtype): + """convert argument to workload""" + dshape = get_const_tuple(data.shape) + kshape = get_const_tuple(kernel.shape) + if len(dshape) > 4: + raw_data = tvm.placeholder((dshape[0], dshape[1] * dshape[4], dshape[2], + dshape[3]), dtype=kernel.dtype) + else: + raw_data = data + if len(kshape) > 4: + raw_kernel = tvm.placeholder((kshape[0] * kshape[5], kshape[1] * kshape[4], + kshape[2], kshape[3]), dtype=kernel.dtype) + else: + raw_kernel = kernel + return ('conv2d_NCHWc', ) + autotvm.task.args_to_workload( + [raw_data, raw_kernel, strides, padding, layout, out_layout, + out_dtype]) + + +def _query_dispatcher(workload, in_alter_op=False): + 
dispatch_ctx = autotvm.task.DispatchContext.current + if isinstance(dispatch_ctx, ApplyGraphBest): + if in_alter_op: + cfg = dispatch_ctx.query(None, None) + else: + cfg = dispatch_ctx.query_global_dict(workload) + else: + target = tvm.target.current_target() + cfg = dispatch_ctx.query(target, workload) + if cfg.is_fallback: + cfg = _get_default_sch(workload) + return cfg + + +@conv2d_alter_layout.register("cpu") +def _alter_conv2d_layout(attrs, inputs, tinfo): + import nnvm.symbol as sym + copy_inputs = [s for s in inputs] + new_attrs = {k : attrs[k] for k in attrs.keys()} + data, kernel = tinfo[0], tinfo[1] + # only optimize for NCHW, groups=1 conv + if attrs['layout'] != 'NCHW' or attrs.get_int("groups") != 1: + return None + + kernel_size = attrs.get_int_tuple("kernel_size") + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + layout = attrs['layout'] + out_layout = layout if attrs["out_layout"] == "__undef__" else attrs["out_layout"] + + dtype = data.dtype + out_dtype = dtype if attrs["out_dtype"] == "same" else attrs["out_dtype"] + workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides, + padding, layout, out_layout, out_dtype) + cfg = _query_dispatcher(workload, True) + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + new_attrs['layout'] = 'NCHW%dc' % ic_bn + new_attrs['out_layout'] = 'NCHW%dc' % oc_bn + + # Store global schedule dictionary for ApplyGraphBest dispatcher + dispatch_ctx = autotvm.task.DispatchContext.current + if isinstance(dispatch_ctx, ApplyGraphBest): + workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides, + padding, new_attrs['layout'], + new_attrs['out_layout'], out_dtype) + global_dict_key = workload + dispatch_ctx.update_global_dict(global_dict_key, cfg) + + # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) + new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) + + return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) + + +@conv2d_NCHWc.register("cpu") +def conv2d_NCHWc_cpu(data, kernel, num_filter, kernel_size, strides, + padding, layout, out_layout, out_dtype): + """x86 conv2d_NCHWc declaration.""" + dispatch_ctx = autotvm.task.DispatchContext.current + if not isinstance(dispatch_ctx, ApplyGraphBest): + layout = out_layout = "NCHW" + workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides, + padding, layout, out_layout, out_dtype) + cfg = _query_dispatcher(workload) + return _declaration_conv_NCHWc(cfg, data, kernel, num_filter, kernel_size, strides, + padding, layout, out_layout, out_dtype) + + +def _declaration_conv_NCHWc(cfg, data, kernel, num_filter, kernel_size, strides, + padding, layout, out_layout, out_dtype): + n, ic_chunk, h, w, ic_block = [x.value for x in data.shape] + ic = ic_chunk * ic_block + kh, kw = kernel_size if isinstance(kernel_size, (tuple, list)) else \ + (kernel_size, kernel_size) + is_kernel_1x1 = kh == 1 and kw == 1 + ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding) + sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) + + if data.dtype == 'uint8': + wkl = _get_workload_int8(tvm.placeholder((n, ic, h, w), dtype=data.dtype), + tvm.placeholder((num_filter, ic, kh, kw), + dtype=kernel.dtype), + strides, padding, out_dtype) + sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout) + return conv2d_avx_1x1._declaration_conv_NCHWc_int8(wkl, sch, data, kernel) \ + if is_kernel_1x1 \ + else conv2d_avx_common._declaration_conv_NCHWc_int8(wkl, sch, data, kernel) + + args = [cfg, 
data, kernel, (kh, kw), (sh, sw), (ph, pw), layout, out_layout, out_dtype] + return _declaration_conv_NCHWc_impl(*args) + + +def _declaration_conv_NCHWc_impl(cfg, data, kernel, kernel_size, strides, padding, layout, + out_layout, out_dtype): + HPAD, WPAD = padding + HSTR, WSTR = strides + + n, ic_chunk, ih, iw, ic_block = get_const_tuple(data.shape) + ic = ic_chunk * ic_block + kh, kw = kernel_size + oc_chunk, _, _, _, _, oc_block = get_const_tuple(kernel.shape) + oc = oc_chunk * oc_block + oh = (ih + 2 * HPAD - kh) // HSTR + 1 + ow = (iw + 2 * WPAD - kw) // WSTR + 1 + + # DOPAD + DOPAD = (HPAD != 0 or WPAD != 0) + if DOPAD: + data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") + else: + data_pad = data + + # fetch schedule + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + if ic_bn != ic_block: + raise RuntimeError("ic_bn in config is not equal to actual data ic_block: %d vs %d." + % (ic_bn, ic_block)) + if oc_bn != oc_block: + raise RuntimeError("oc_bn in config is not equal to actual kernel oc_block: %d vs %d." + % (oc_bn, oc_block)) + + # convolution + oshape = (n, oc//oc_bn, oh, ow, oc_bn) + + ic = tvm.reduce_axis((0, ic), name='ic') + kh = tvm.reduce_axis((0, kernel_size[0]), name='kh') + kw = tvm.reduce_axis((0, kernel_size[1]), name='kw') + + workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, + strides, padding, layout, + out_layout, out_dtype), + attrs = {'workload': workload} + conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: + tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw, + ic%ic_bn].astype(out_dtype) * + kernel[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, oc_block], + axis=[ic, kh, kw]), + name='conv2d_NCHWc', tag="conv2d_NCHWc", attrs=attrs) + return conv + + +@generic.schedule_conv2d_NCHWc.register("cpu") +def schedule_conv2d_NCHWc(num_filter, kernel_size, strides, padding, layout, out_layout, outs): + """x86 conv2d_NCHWc schedule""" + return _schedule_conv2d_NCHWc(None, num_filter, kernel_size, strides, padding, + layout, out_layout, outs) + + +def _schedule_conv2d_NCHWc(cfg, num_filter, kernel_size, strides, padding, + layout, out_layout, outs): """Create schedule for tensors""" s = tvm.create_schedule([x.op for x in outs]) scheduled_ops = [] + dispatch_ctx = autotvm.task.DispatchContext.current + if not isinstance(dispatch_ctx, ApplyGraphBest): + layout = out_layout = "NCHW" def traverse(op): """Traverse operators from computation graph""" @@ -416,18 +702,9 @@ def traverse(op): data_pad = data data = data_pad.op.input_tensors[0] - _AVX_SCH_TO_SCH_FUNC = { - AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc, - AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc - } - - # Use int8 schedules if the input data is of int8 dtype - if data.dtype == 'uint8': - _AVX_SCH_TO_SCH_FUNC = { - AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc_int8, - AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc_int8 - } - + kh, kw = kernel_size if isinstance(kernel_size, (tuple, list)) else \ + (kernel_size, kernel_size) + is_kernel_1x1 = kh == 1 and kw == 1 n, ic_chunk, h, w, ic_block = [x.value for x in data.shape] ic = ic_chunk * ic_block original_data = tvm.placeholder((n, ic, h, w), dtype=data.dtype) @@ -435,16 +712,27 @@ def traverse(op): kh, kw = kernel_size original_kernel = tvm.placeholder((num_filter, ic, kh, kw), dtype=kernel.dtype) - if data.dtype == 'uint8': wkl = _get_workload_int8(original_data, original_kernel, - stride, padding, conv_out.dtype) + strides, padding, conv_out.dtype) sch = 
_get_schedule_NCHWc_int8(wkl, layout, out_layout) + args = [s, wkl, sch, data_vec, kernel, conv_out, outs[0]] + if is_kernel_1x1: + conv2d_avx_1x1._schedule_conv_NCHWc_int8(*args) + else: + conv2d_avx_common._schedule_conv_NCHWc_int8(*args) else: - wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype) - sch = _get_schedule_NCHWc(wkl, layout, out_layout) - _AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec, - kernel, conv_out, outs[0]) + current_cfg = cfg + if current_cfg is None: + workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides, + padding, layout, out_layout, + conv_out.dtype) + current_cfg = _query_dispatcher(workload) + args = [s, current_cfg, data_vec, conv_out, outs[0]] + if is_kernel_1x1: + conv2d_avx_1x1._schedule_conv_NCHWc(*args) + else: + conv2d_avx_common._schedule_conv_NCHWc(*args) scheduled_ops.append(op) diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index bace7451d665..96affc7b9d23 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -3,11 +3,11 @@ from __future__ import absolute_import as _abs from collections import namedtuple import tvm +from tvm.autotvm.task import ConfigEntity + import topi -from ..util import get_const_tuple -from ..nn.conv2d import _get_schedule, _get_workload -from ..nn.util import infer_pad, infer_stride +from ..nn.util import infer_pad from ..nn.pad import pad from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .check_targets import check_skylake @@ -42,62 +42,51 @@ def _get_default_schedule(wkl, simd_width): raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) -def _declaration_conv(data, kernel, stride, padding, layout, out_dtype): - assert layout == 'NCHW', "only support NCHW convolution for AVX" - wkl = _get_workload(data, kernel, stride, padding, out_dtype) - sch = _get_schedule(wkl) - - HPAD, WPAD = wkl.hpad, wkl.wpad - HSTR, WSTR = wkl.hstride, wkl.wstride +def _fallback_schedule(wkl, simd_width): + batch_size, in_channel, height, width, _ = wkl[1] + out_channel, _, hkernel, wkernel, _ = wkl[2] + HPAD, WPAD = wkl[4] + HSTR, WSTR = wkl[3] + out_height = (height + 2 * HPAD - hkernel) // HSTR + 1 + out_width = (width + 2 * WPAD - wkernel) // WSTR + 1 - batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape) - num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape) - - pad_height = in_height + 2 * HPAD - pad_width = in_width + 2 * WPAD - - out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1 - out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1 + oc_bn = 1 + for bn in range(simd_width, 0, -1): + if out_channel % bn == 0: + oc_bn = bn + break - DOPAD = (HPAD != 0 or WPAD != 0) - if DOPAD: - data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") - else: - data_pad = data - shape = (batch_size, in_channel // sch.ic_bn, pad_height, pad_width, sch.ic_bn) - data_vec = tvm.compute(shape, lambda n, C, h, w, c: data_pad[n, C * sch.ic_bn + c, h, w]) + ic_bn = 1 + for bn in range(oc_bn, 0, -1): + if in_channel % bn == 0: + ic_bn = bn + break - shape = (num_filter // sch.oc_bn, in_channel // sch.ic_bn, sch.ic_bn, sch.oc_bn, 1, 1) - kernel_vec = tvm.compute(shape, lambda CO, CI, ci, co, h, w: - kernel[CO * sch.oc_bn + co, CI * sch.ic_bn + ci, h, w], - name='kernel_vec') + for ow_factor in range(out_width, 0, -1): + if out_width % ow_factor == 0: + for oh_factor in range(out_height, 0, -1): + if out_height % 
oh_factor == 0 and ow_factor * oh_factor < 32: + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [in_channel // ic_bn, ic_bn]], + ["tile_oc", "sp", [out_channel // oc_bn, oc_bn]], + ["tile_oh", "ot", oh_factor], + ["tile_ow", "sp", [out_width // ow_factor, + ow_factor]],], + "t": ""} + return ConfigEntity.from_json_dict(cfg_dict) - oshape = (batch_size, num_filter // sch.oc_bn, out_height, out_width, sch.oc_bn) - ic = tvm.reduce_axis((0, in_channel), name='ic') - conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: - tvm.sum(data_vec[n, ic//sch.ic_bn, oh*HSTR, ow*WSTR, ic%sch.ic_bn] * - kernel_vec[oc_chunk, ic//sch.ic_bn, ic%sch.ic_bn, oc_block, 0, 0], - axis=[ic]), name='conv') + raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) - oshape = (batch_size, num_filter, out_height, out_width) - unpack = tvm.compute(oshape, lambda n, oc, oh, ow: - conv[n, oc // sch.oc_bn, oh, ow, oc % sch.oc_bn], - tag='conv2d_nchw') - return unpack +def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last): + # fetch schedule + ic_bn, oc_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], + cfg["tile_oh"].val, cfg["tile_ow"].size[-1]) -def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, output, last): # no stride and padding info here padding = infer_pad(data, data_pad) - if data_pad is None: - stride = infer_stride(data, kernel, output) - else: - stride = infer_stride(data_pad, kernel, output) - - wkl = _get_workload(data, kernel, stride, padding, output.dtype) - sch = _get_schedule(wkl) - - HPAD, WPAD = wkl.hpad, wkl.wpad + HPAD, WPAD = padding DOPAD = (HPAD != 0 or WPAD != 0) A, W = data, kernel_vec @@ -112,7 +101,7 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou # schedule kernel pack oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block) - if sch.oc_bn > 1: + if oc_bn > 1: s[W].vectorize(oc_block) parallel_axis = s[W].fuse(oc_chunk, oh) s[W].parallel(parallel_axis) @@ -121,17 +110,17 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou CC = s.cache_write(C, 'global') batch, oc_chunk, oh, ow, oc_block = s[C].op.axis - oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor) + oh_outer, oh_inner = s[C].split(oh, factor=oh_factor) s[C].vectorize(oc_block) s[CC].compute_at(s[C], oh_outer) _, oc_chunk, oh, ow, oc_block = s[CC].op.axis - ic, = s[CC].op.reduce_axis + ic, _, _ = s[CC].op.reduce_axis - ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn) + ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn) - oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor) - ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor) + oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor) s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner, ow_inner, oc_block) s[CC].vectorize(oc_block) @@ -143,9 +132,9 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou s[O0].compute_inline() batch, oc, oh, ow = s[O].op.axis - oc_chunk, oc_block = s[O].split(oc, factor=sch.oc_bn) - oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor) - ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor) + oc_chunk, oc_block = s[O].split(oc, factor=oc_bn) + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) 
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) parallel_axis = s[O].fuse(oc_chunk, oh_outer) @@ -157,33 +146,11 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou return s -def _declaration_conv_NCHWc(wkl, sch, data, kernel): - out_dtype = wkl.out_dtype - HPAD, WPAD = wkl.hpad, wkl.wpad - HSTR, WSTR = wkl.hstride, wkl.wstride - - batch_size = data.shape[0] - out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 - out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 - - DOPAD = (HPAD != 0 or WPAD != 0) - if DOPAD: - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - else: - data_pad = data - - oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn) - ic = tvm.reduce_axis((0, wkl.in_filter), name='ic') - conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: - tvm.sum(data_pad[n, ic//sch.ic_bn, oh*HSTR, ow*WSTR, ic%sch.ic_bn] - .astype(out_dtype) * - kernel[oc_chunk, ic // sch.ic_bn, ic % sch.ic_bn, oc_block, 0, 0], - axis=[ic]), name='conv2d_NCHWc', tag='conv2d_NCHWc') - - return conv - +def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): + # fetch schedule + ic_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oh"].val, + cfg["tile_ow"].size[-1]) -def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last): # schedule data A = data if isinstance(s[A].op, tvm.tensor.ComputeOp): @@ -195,8 +162,8 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last): CC = s.cache_write(C, 'global') batch, oc_chunk, oh, ow, oc_block = s[C].op.axis - oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor) - ow_outer, ow_inner = s[C].split(ow, factor=sch.ow_factor) + oh_outer, oh_inner = s[C].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[C].split(ow, factor=ow_factor) s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) s[C].vectorize(oc_block) @@ -206,12 +173,12 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last): s[C].parallel(parallel_axis) _, oc_chunk, oh, ow, oc_block = s[CC].op.axis - ic, = s[CC].op.reduce_axis + ic, _, _ = s[CC].op.reduce_axis - ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn) + ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn) - oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor) - ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor) + oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor) s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner, ow_inner, oc_block) s[CC].fuse(oc_chunk, oh_outer) @@ -222,8 +189,8 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last): if C != O: batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor) - ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor) + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) parallel_axis = s[O].fuse(oc_chunk, oh_outer) diff --git a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py index 0d7aba23d236..eaa3d15e64b0 100644 --- a/topi/python/topi/x86/conv2d_avx_common.py +++ b/topi/python/topi/x86/conv2d_avx_common.py @@ -3,10 +3,9 @@ from __future__ import absolute_import as _abs from collections import namedtuple import tvm +from tvm.autotvm.task import ConfigEntity -from ..util 
import get_const_tuple -from ..nn.conv2d import _get_schedule, _get_workload -from ..nn.util import infer_pad, infer_stride +from ..nn.util import infer_pad from ..nn.pad import pad from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .check_targets import check_skylake @@ -17,7 +16,6 @@ def _get_default_schedule(wkl, simd_width): HPAD, WPAD = wkl.hpad, wkl.wpad HSTR, WSTR = wkl.hstride, wkl.wstride - out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 oc_bn = 1 @@ -41,78 +39,49 @@ def _get_default_schedule(wkl, simd_width): return AVXConvCommonFwd(ic_bn, oc_bn, reg_n, False) -def _declaration_conv(data, kernel, stride, padding, layout, out_dtype): - out_dtype = data.dtype if out_dtype is None else out_dtype - assert layout == 'NCHW', "only support NCHW convolution for AVX" - wkl = _get_workload(data, kernel, stride, padding, out_dtype) - sch = _get_schedule(wkl) - - HPAD, WPAD = wkl.hpad, wkl.wpad - HSTR, WSTR = wkl.hstride, wkl.wstride - - batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape) - num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape) - - pad_height = in_height + 2 * HPAD - pad_width = in_width + 2 * WPAD +def _fallback_schedule(wkl, simd_width): + batch_size, in_channel, height, width, _ = wkl[1] + out_channel, _, hkernel, wkernel, _ = wkl[2] + HPAD, WPAD = wkl[4] + HSTR, WSTR = wkl[3] + out_width = (width + 2 * WPAD - wkernel) // WSTR + 1 - out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1 - out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1 + oc_bn = 1 + for bn in range(simd_width, 0, -1): + if out_channel % bn == 0: + oc_bn = bn + break - # pack data - DOPAD = (HPAD != 0 or WPAD != 0) - if DOPAD: - data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") - else: - data_pad = data + ic_bn = 1 + for bn in range(oc_bn, 0, -1): + if in_channel % bn == 0: + ic_bn = bn + break - shape = (batch_size, in_channel // sch.ic_bn, pad_height, sch.ic_bn, pad_width) - data_vec = tvm.compute(shape, - lambda n, C, h, c, w: data_pad[n, C * sch.ic_bn + c, h, w], - name='data_vec') + reg_n = 1 + for n in range(31, 0, -1): + if out_width % n == 0: + reg_n = n + break - # pack kernel - shape = (num_filter//sch.oc_bn, in_channel//sch.ic_bn, - kernel_height, kernel_width, sch.ic_bn, sch.oc_bn) - kernel_vec = tvm.compute(shape, lambda CO, CI, h, w, ci, co: - kernel[CO * sch.oc_bn + co, CI * sch.ic_bn + ci, h, w], - name='kernel_vec') + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [in_channel // ic_bn, ic_bn]], + ["tile_oc", "sp", [out_channel // oc_bn, oc_bn]], + ["tile_ow", "sp", [out_width // reg_n, reg_n]], + ["unroll_kw", "ot", False]], + "t": ""} + return ConfigEntity.from_json_dict(cfg_dict) - # convolution - oshape = (batch_size, num_filter//sch.oc_bn, out_height, out_width, sch.oc_bn) - unpack_shape = (batch_size, num_filter, out_height, out_width) - ic = tvm.reduce_axis((0, in_channel), name='ic') - kh = tvm.reduce_axis((0, kernel_height), name='kh') - kw = tvm.reduce_axis((0, kernel_width), name='kw') +def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last): + # fetch schedule + ic_bn, oc_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], + cfg["tile_ow"].size[-1], cfg["unroll_kw"].val) - conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: - tvm.sum(data_vec[n, ic//sch.ic_bn, oh*HSTR+kh, ic%sch.ic_bn, ow*WSTR+kw] - .astype(out_dtype) * - kernel_vec[oc_chunk, 
ic//sch.ic_bn, kh, kw, ic%sch.ic_bn, oc_block] - .astype(out_dtype), - axis=[ic, kh, kw]), - name='conv') - - unpack = tvm.compute(unpack_shape, - lambda n, c, h, w: conv[n, c // sch.oc_bn, h, w, c % sch.oc_bn] - .astype(out_dtype), - name='output_unpack', - tag='conv2d_nchw') - return unpack - - -def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, output, last): # no stride and padding info here padding = infer_pad(data, data_pad) - if data_pad is None: - stride = infer_stride(data, kernel, output) - else: - stride = infer_stride(data_pad, kernel, output) - wkl = _get_workload(data, kernel, stride, padding, output.dtype) - sch = _get_schedule(wkl) - - HPAD, WPAD = wkl.hpad, wkl.wpad + HPAD, WPAD = padding DOPAD = (HPAD != 0 or WPAD != 0) A, W = data, kernel_vec @@ -128,7 +97,7 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou # schedule kernel pack oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block) - if sch.oc_bn > 1: + if oc_bn > 1: s[W].vectorize(oc_block) parallel_axis = s[W].fuse(oc_chunk, oh) s[W].parallel(parallel_axis) @@ -138,7 +107,7 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou CC = s.cache_write(C, 'global') _, oc_chunk, oh, ow, oc_block = s[C].op.axis - ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n) + ow_chunk, ow_block = s[C].split(ow, factor=reg_n) s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) s[C].fuse(oc_chunk, oh) s[C].vectorize(oc_block) @@ -147,10 +116,10 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou _, oc_chunk, oh, ow, oc_block = s[CC].op.axis ic, kh, kw = s[CC].op.reduce_axis - ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n) - ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn) + ow_chunk, ow_block = s[CC].split(ow, factor=reg_n) + ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn) - if sch.unroll_kw: + if unroll_kw: s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw, ow_block, oc_block) s[CC].unroll(kw) else: @@ -164,8 +133,8 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou s[O0].compute_inline() batch, oc, oh, ow = s[O].op.axis - ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n) - oc_chunk, oc_block = s[O].split(oc, factor=sch.oc_bn) + ow_chunk, ow_block = s[O].split(ow, factor=reg_n) + oc_chunk, oc_block = s[O].split(oc, factor=oc_bn) s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) parallel_axis = s[O].fuse(oc_chunk, oh) s[C].compute_at(s[O], parallel_axis) @@ -176,39 +145,11 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou return s -def _declaration_conv_NCHWc(wkl, sch, data, kernel): - out_dtype = wkl.out_dtype - HPAD, WPAD = wkl.hpad, wkl.wpad - HSTR, WSTR = wkl.hstride, wkl.wstride - - batch_size = data.shape[0] - out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 - out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 - - # pack data - DOPAD = (HPAD != 0 or WPAD != 0) - if DOPAD: - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - else: - data_pad = data - - # convolution - oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn) - - ic = tvm.reduce_axis((0, wkl.in_filter), name='ic') - kh = tvm.reduce_axis((0, wkl.hkernel), name='kh') - kw = tvm.reduce_axis((0, wkl.wkernel), name='kw') - - conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: - tvm.sum(data_pad[n, 
ic//sch.ic_bn, oh*HSTR+kh, ow*WSTR+kw, ic%sch.ic_bn] - .astype(out_dtype) * - kernel[oc_chunk, ic//sch.ic_bn, kh, kw, ic%sch.ic_bn, oc_block], - axis=[ic, kh, kw]), name='conv2d_NCHWc', tag="conv2d_NCHWc") +def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): + # fetch schedule + ic_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_ow"].size[-1], + cfg["unroll_kw"].val) - return conv - - -def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last): # schedule data A = data if isinstance(s[A].op, tvm.tensor.ComputeOp): @@ -221,7 +162,7 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last): CC = s.cache_write(C, 'global') _, oc_chunk, oh, ow, oc_block = s[C].op.axis - ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n) + ow_chunk, ow_block = s[C].split(ow, factor=reg_n) s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) parallel_axis = s[C].fuse(oc_chunk, oh) s[C].vectorize(oc_block) @@ -232,10 +173,10 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last): _, oc_chunk, oh, ow, oc_block = s[CC].op.axis ic, kh, kw = s[CC].op.reduce_axis - ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n) - ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn) + ow_chunk, ow_block = s[CC].split(ow, factor=reg_n) + ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn) - if sch.unroll_kw: + if unroll_kw: s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw, ow_block, oc_block) s[CC].unroll(kw) else: @@ -246,7 +187,7 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last): if C != O: batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n) + ow_chunk, ow_block = s[O].split(ow, factor=reg_n) s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) parallel_axis = s[O].fuse(oc_chunk, oh) s[C].compute_at(s[O], parallel_axis) diff --git a/tutorials/autotvm/tune_nnvm_x86.py b/tutorials/autotvm/tune_nnvm_x86.py new file mode 100644 index 000000000000..ddd91f584c08 --- /dev/null +++ b/tutorials/autotvm/tune_nnvm_x86.py @@ -0,0 +1,220 @@ +""" +Auto-tuning a convolutional network for x86 CPU +==================================================== +**Author**: `Yao Wang `_ + +This is a tutorial on how to tune a convolutional neural network +for x86 CPUs. +""" +import os +import numpy as np + +import nnvm.testing +import nnvm.compiler +import tvm +from tvm import autotvm +from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner +from topi.x86.conv2d import conv_NCHWc_arg_to_workload +import tvm.contrib.graph_runtime as runtime + +################################################################# +# Define network +# -------------- +# First we need to define the network in the nnvm symbol API. +# We can load some pre-defined networks from :code:`nnvm.testing`. +# We can also load models from MXNet, ONNX and TensorFlow (see NNVM +# tutorials :ref:`tutorial-nnvm` for more details). +# +# In this tutorial, we choose resnet-18 as the tuning example.
+ +def get_network(name, batch_size): + """Get the symbol definition and random weights of a network""" + input_shape = (batch_size, 3, 224, 224) + output_shape = (batch_size, 1000) + + if "resnet" in name: + n_layer = int(name.split('-')[1]) + net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size) + elif "vgg" in name: + n_layer = int(name.split('-')[1]) + net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size) + elif name == 'mobilenet': + net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size) + elif name == 'squeezenet_v1.1': + net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1') + elif name == 'inception_v3': + input_shape = (1, 3, 299, 299) + net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size) + elif name == 'custom': + # an example for custom network + from nnvm.testing import utils + net = nnvm.sym.Variable('data') + net = nnvm.sym.conv2d(net, channels=4, kernel_size=(3,3), padding=(1,1)) + net = nnvm.sym.flatten(net) + net = nnvm.sym.dense(net, units=1000) + net, params = utils.create_workload(net, batch_size, (3, 224, 224)) + elif name == 'mxnet': + # an example for mxnet model + from mxnet.gluon.model_zoo.vision import get_model + block = get_model('resnet18_v1', pretrained=True) + net, params = nnvm.frontend.from_mxnet(block) + net = nnvm.sym.softmax(net) + else: + raise ValueError("Unsupported network: " + name) + + return net, params, input_shape, output_shape + +# Replace "llvm" with the correct target of your CPU. +# For example, for an AWS EC2 c5 instance with Intel Xeon +# Platinum 8000 series, the target should be "llvm -mcpu=skylake-avx512". +# For an AWS EC2 c4 instance with Intel Xeon E5-2666 v3, it should be +# "llvm -mcpu=core-avx2". +target = "llvm" + +batch_size = 1 +dtype = "float32" +model_name = "resnet-18" +log_file = "%s.log" % model_name + +# Set the number of threads used for tuning based on the number of +# physical CPU cores on your machine. +num_threads = 1 +os.environ["TVM_NUM_THREADS"] = str(num_threads) + + +################################################################# +# Configure tensor tuning settings and create tasks +# ------------------------------------------------- +# To get better kernel execution performance on x86 CPUs, +# we need to change the data layout of the convolution kernel from +# "NCHW" to "NCHWc". To deal with this, we define +# the conv2d_NCHWc operator in topi. We will tune this operator +# instead of plain conv2d. +# +# We will use local mode for tuning configuration. RPC tracker +# mode can be set up similarly to the approach in +# the :ref:`tune_nnvm_arm` tutorial. + +tuning_option = { + 'log_filename': log_file, + 'tuner': 'random', + 'early_stopping': None, + + 'measure_option': autotvm.measure_option( + builder=autotvm.LocalBuilder(), + runner=autotvm.LocalRunner(number=10, repeat=1, + min_repeat_ms=1000), + ), +} + +# You can skip the implementation of this function for this tutorial.
+def tune_kernels(tasks, + measure_option, + tuner='gridsearch', + early_stopping=None, + log_filename='tuning.log'): + + for i, tsk in enumerate(tasks): + prefix = "[Task %2d/%2d] " % (i+1, len(tasks)) + + # converting conv2d tasks to conv2d_NCHWc tasks + data, kernel, strides, padding, layout, dtype = tsk.args + kernel_size = (kernel[1][2], kernel[1][3]) + data_plc = tvm.placeholder(data[1], name="data") + kernel_plc = tvm.placeholder(kernel[1], name="kernel") + args = [data_plc, kernel_plc, data[1][1], kernel_size, strides, + padding, layout, layout, dtype] + args = autotvm.task.nnvm_integration.serialize_args(args) + task = autotvm.task.create("topi_x86_conv2d_NCHWc", args=args, target=target) + task.workload = conv_NCHWc_arg_to_workload(data_plc, kernel_plc, kernel_size, + strides, padding, layout, layout, dtype) + + # create tuner + if tuner == 'xgb' or tuner == 'xgb-rank': + tuner_obj = XGBTuner(task, loss_type='rank') + elif tuner == 'ga': + tuner_obj = GATuner(task, pop_size=50) + elif tuner == 'random': + tuner_obj = RandomTuner(task) + elif tuner == 'gridsearch': + tuner_obj = GridSearchTuner(task) + else: + raise ValueError("Invalid tuner: " + tuner) + + # do tuning + n_trial=len(task.config_space) + tuner_obj.tune(n_trial=n_trial, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(n_trial, prefix=prefix), + autotvm.callback.log_to_file(log_filename)]) + + +######################################################################## +# Finally, we launch tuning jobs and evaluate the end-to-end performance. + +def tune_and_evaluate(tuning_opt): + # extract workloads from nnvm graph + print("Extract tasks...") + net, params, data_shape, out_shape = get_network(model_name, batch_size) + tasks = autotvm.task.extract_from_graph(net, target=target, + shape={'data': data_shape}, dtype=dtype, + symbols=(nnvm.sym.conv2d,)) + + # run tuning tasks + print("Tuning...") + tune_kernels(tasks, **tuning_opt) + + # compile kernels with history best records + with autotvm.apply_history_best(log_file): + print("Compile...") + with nnvm.compiler.build_config(opt_level=3): + graph, lib, params = nnvm.compiler.build( + net, target=target, shape={'data': data_shape}, params=params, dtype=dtype) + + # upload parameters to device + ctx = tvm.cpu() + data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype)) + module = runtime.create(graph, lib, ctx) + module.set_input('data', data_tvm) + module.set_input(**params) + + # evaluate + print("Evaluate inference time cost...") + ftimer = module.module.time_evaluator("run", ctx, number=100, repeat=3) + prof_res = np.array(ftimer().results) * 1000 # convert to millisecond + print("Mean inference time (std dev): %.2f ms (%.2f ms)" % + (np.mean(prof_res), np.std(prof_res))) + +# We do not run the tuning in our webpage server since it takes too long. +# Uncomment the following line to run it by yourself. + +# tune_and_evaluate(tuning_option) + +###################################################################### +# Sample Output +# ------------- +# The tuning needs to compile many programs and extract feature from them. +# So a high performance CPU is recommended. +# One sample output is listed below. +# +# .. code-block:: bash +# +# Extract tasks... +# Tuning... +# [Task 1/12] Current/Best: 598.05/2497.63 GFLOPS | Progress: (252/252) | 1357.95 s Done. +# [Task 2/12] Current/Best: 522.63/2279.24 GFLOPS | Progress: (784/784) | 3989.60 s Done. 
+# [Task 3/12] Current/Best: 447.33/1927.69 GFLOPS | Progress: (784/784) | 3869.14 s Done. +# [Task 4/12] Current/Best: 481.11/1912.34 GFLOPS | Progress: (672/672) | 3274.25 s Done. +# [Task 5/12] Current/Best: 414.09/1598.45 GFLOPS | Progress: (672/672) | 2720.78 s Done. +# [Task 6/12] Current/Best: 508.96/2273.20 GFLOPS | Progress: (768/768) | 3718.75 s Done. +# [Task 7/12] Current/Best: 469.14/1955.79 GFLOPS | Progress: (576/576) | 2665.67 s Done. +# [Task 8/12] Current/Best: 230.91/1658.97 GFLOPS | Progress: (576/576) | 2435.01 s Done. +# [Task 9/12] Current/Best: 487.75/2295.19 GFLOPS | Progress: (648/648) | 3009.95 s Done. +# [Task 10/12] Current/Best: 182.33/1734.45 GFLOPS | Progress: (360/360) | 1755.06 s Done. +# [Task 11/12] Current/Best: 372.18/1745.15 GFLOPS | Progress: (360/360) | 1684.50 s Done. +# [Task 12/12] Current/Best: 215.34/2271.11 GFLOPS | Progress: (400/400) | 2128.74 s Done. +# Compile... +# Evaluate inference time cost... +# Mean inference time (std dev): 3.16 ms (0.03 ms)
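As a side note, the ApplyGraphBest context added in dispatcher.py (exported as autotvm.apply_graph_best) consumes per-node records in ascending node-index order, typically produced by a graph-level tuner that is not part of this diff. A minimal sketch of how it could stand in for apply_history_best in the compile step of the tutorial above, assuming a hypothetical record file "graph_opt.log" already in that order:

# Sketch only: "graph_opt.log" is a hypothetical graph-tuner output file;
# net, target, data_shape, params and dtype are defined as in the tutorial.
with autotvm.apply_graph_best("graph_opt.log"):
    with nnvm.compiler.build_config(opt_level=3):
        graph, lib, params = nnvm.compiler.build(
            net, target=target, shape={'data': data_shape},
            params=params, dtype=dtype)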