diff --git a/Makefile b/Makefile index 1c8d70ecc695..a4b41b8d8371 100644 --- a/Makefile +++ b/Makefile @@ -66,8 +66,8 @@ $(warning "USE_MKL2017 is deprecated. We will switch to USE_MKLDNN.") endif ifeq ($(USE_MKLDNN), 1) - MKLDNNROOT = $(ROOTDIR)/3rdparty/mkldnn/install - MKLROOT = $(ROOTDIR)/3rdparty/mkldnn/install + MKLDNNROOT = $(ROOTDIR)/3rdparty/mkldnn/build/install + MKLROOT = $(ROOTDIR)/3rdparty/mkldnn/build/install export USE_MKLML = 1 endif diff --git a/example/quantization/imagenet_gen_qsym.py b/example/quantization/imagenet_gen_qsym.py index 85474b663fae..8a2818c4bca0 100644 --- a/example/quantization/imagenet_gen_qsym.py +++ b/example/quantization/imagenet_gen_qsym.py @@ -92,7 +92,7 @@ def save_params(fname, arg_params, aux_params, logger=None): ' thresholds. This mode is expected to produce the best inference accuracy of all three' ' kinds of quantized models if the calibration dataset is representative enough of the' ' inference dataset.') - parser.add_argument('--quantized-dtype', type=str, default='int8', + parser.add_argument('--quantized-dtype', type=str, default='int8', choices=['int8', 'uint8'], help='quantization destination data type for input data') args = parser.parse_args() diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py new file mode 100644 index 000000000000..e06276115154 --- /dev/null +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import os +import logging +from common import modelzoo +import mxnet as mx +from mxnet.contrib.quantization import * +from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array +import ctypes + + +def download_calib_dataset(dataset_url, calib_dataset, logger=None): + if logger is not None: + logger.info('Downloading calibration dataset from %s to %s' % (dataset_url, calib_dataset)) + mx.test_utils.download(dataset_url, calib_dataset) + + +def download_model(model_name, logger=None): + dir_path = os.path.dirname(os.path.realpath(__file__)) + model_path = os.path.join(dir_path, 'model') + if logger is not None: + logger.info('Downloading model %s... into path %s' % (model_name, model_path)) + return modelzoo.download_model(args.model, os.path.join(dir_path, 'model')) + + +def save_symbol(fname, sym, logger=None): + if logger is not None: + logger.info('Saving symbol into file at %s' % fname) + sym.save(fname) + + +def save_params(fname, arg_params, aux_params, logger=None): + if logger is not None: + logger.info('Saving params into file at %s' % fname) + save_dict = {('arg:%s' % k): v.as_in_context(cpu()) for k, v in arg_params.items()} + save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()}) + mx.nd.save(fname, save_dict) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from a FP32 model with MKL-DNN support') + parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'], + help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn') + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--label-name', type=str, default='softmax_label') + parser.add_argument('--calib-dataset', type=str, default='data/val_256_q90.rec', + help='path of the calibration dataset') + parser.add_argument('--image-shape', type=str, default='3,224,224') + parser.add_argument('--data-nthreads', type=int, default=60, + help='number of threads for data decoding') + parser.add_argument('--num-calib-batches', type=int, default=10, + help='number of batches for calibration') + parser.add_argument('--exclude-first-conv', action='store_true', default=True, + help='excluding quantizing the first conv layer since the' + ' input data may have negative value which doesn\'t support at moment' ) + parser.add_argument('--shuffle-dataset', action='store_true', default=True, + help='shuffle the calibration dataset') + parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304, + help='shuffling chunk seed, see' + ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter' + ' for more details') + parser.add_argument('--shuffle-seed', type=int, default=48564309, + help='shuffling seed, see' + ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter' + ' for more details') + parser.add_argument('--calib-mode', type=str, default='entropy', + help='calibration mode used for generating calibration table for the quantized symbol; supports' + ' 1. none: no calibration will be used. The thresholds for quantization will be calculated' + ' on the fly. This will result in inference speed slowdown and loss of accuracy' + ' in general.' + ' 2. naive: simply take min and max values of layer outputs as thresholds for' + ' quantization. In general, the inference accuracy worsens with more examples used in' + ' calibration. It is recommended to use `entropy` mode as it produces more accurate' + ' inference results.' + ' 3. entropy: calculate KL divergence of the fp32 output and quantized output for optimal' + ' thresholds. This mode is expected to produce the best inference accuracy of all three' + ' kinds of quantized models if the calibration dataset is representative enough of the' + ' inference dataset.') + parser.add_argument('--quantized-dtype', type=str, default='uint8', + choices=['int8', 'uint8'], + help='quantization destination data type for input data') + parser.add_argument('--enable-calib-quantize', type=bool, default=True, + help='If enabled, the quantize op will ' + 'be calibrated offline if calibration mode is ' + 'enabled') + args = parser.parse_args() + ctx = mx.cpu(0) + logging.basicConfig() + logger = logging.getLogger('logger') + logger.setLevel(logging.INFO) + + logger.info('shuffle_dataset=%s' % args.shuffle_dataset) + + calib_mode = args.calib_mode + logger.info('calibration mode set to %s' % calib_mode) + + # download calibration dataset + if calib_mode != 'none': + download_calib_dataset('http://data.mxnet.io/data/val_256_q90.rec', args.calib_dataset) + + # download model + prefix, epoch = download_model(model_name=args.model, logger=logger) + sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + + sym = sym.get_backend_symbol('MKLDNN') + + # get batch size + batch_size = args.batch_size + logger.info('batch size = %d for calibration' % batch_size) + + # get number of batches for calibration + num_calib_batches = args.num_calib_batches + if calib_mode == 'none': + logger.info('skip calibration step as calib_mode is none') + else: + logger.info('number of batches = %d for calibration' % num_calib_batches) + + # get number of threads for decoding the dataset + data_nthreads = args.data_nthreads + + # get image shape + image_shape = args.image_shape + + exclude_first_conv = args.exclude_first_conv + excluded_sym_names = [] + if args.model == 'imagenet1k-resnet-152': + rgb_mean = '0,0,0' + calib_layer = lambda name: name.endswith('_output') + excluded_sym_names += ['flatten0', 'fc1'] + if exclude_first_conv: + excluded_sym_names += ['conv0', 'pooling0'] + elif args.model == 'imagenet1k-inception-bn': + rgb_mean = '123.68,116.779,103.939' + calib_layer = lambda name: name.endswith('_output') + excluded_sym_names += ['flatten', 'fc1'] + if exclude_first_conv: + excluded_sym_names += ['conv_1'] + else: + raise ValueError('model %s is not supported in this script' % args.model) + + label_name = args.label_name + logger.info('label_name = %s' % label_name) + + data_shape = tuple([int(i) for i in image_shape.split(',')]) + logger.info('Input data shape = %s' % str(data_shape)) + + logger.info('rgb_mean = %s' % rgb_mean) + rgb_mean = [float(i) for i in rgb_mean.split(',')] + mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]} + + if calib_mode == 'none': + logger.info('Quantizing FP32 model %s' % args.model) + qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, + ctx=ctx, excluded_sym_names=excluded_sym_names, + calib_mode=calib_mode, quantized_dtype=args.quantized_dtype, + logger=logger) + sym_name = '%s-symbol.json' % (prefix + '-quantized') + else: + logger.info('Creating ImageRecordIter for reading calibration dataset') + data = mx.io.ImageRecordIter(path_imgrec=args.calib_dataset, + label_width=1, + preprocess_threads=data_nthreads, + batch_size=batch_size, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=args.shuffle_dataset, + shuffle_chunk_seed=args.shuffle_chunk_seed, + seed=args.shuffle_seed, + **mean_args) + + qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, + ctx=ctx, excluded_sym_names=excluded_sym_names, + calib_mode=calib_mode, calib_data=data, + num_calib_examples=num_calib_batches * batch_size, + calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, + label_names=(label_name,), calib_quantize_op = True, + logger=logger) + if calib_mode == 'entropy': + suffix = '-quantized-%dbatches-entropy' % num_calib_batches + elif calib_mode == 'naive': + suffix = '-quantized-%dbatches-naive' % num_calib_batches + else: + raise ValueError('unknow calibration mode %s received, only supports `none`, `naive`, and `entropy`' + % calib_mode) + sym_name = '%s-symbol.json' % (prefix + suffix) + qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE') + save_symbol(sym_name, qsym, logger) + param_name = '%s-%04d.params' % (prefix + '-quantized', epoch) + save_params(param_name, qarg_params, aux_params, logger) diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py index 85649530aa0b..286e49ea4401 100644 --- a/example/quantization/imagenet_inference.py +++ b/example/quantization/imagenet_inference.py @@ -129,7 +129,7 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples, ctx = mx.cpu(0) else: raise ValueError('ctx %s is not supported in this script' % args.ctx) - + logging.basicConfig() logger = logging.getLogger('logger') logger.setLevel(logging.INFO) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index a01cc6a77940..dc33c95437f9 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1542,18 +1542,17 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, * \param sym_handle symbol to be converted * \param ret_sym_handle quantized symbol result * \param num_excluded_symbols number of layers excluded from being quantized in the input symbol - * \param excluded_symbols array of symbols to be excluded from being quantized + * \param excluded_symbols op names to be excluded from being quantized * \param num_offline number of parameters that are quantized offline * \param offline_params array of c strings representing the names of params quantized offline * \param quantized_dtype the quantized destination type for input data. + * \param calib_quantize whether calibrate quantize op with offline calibration data. */ -MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, - SymbolHandle *ret_sym_handle, +MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const mx_uint num_excluded_symbols, - const SymbolHandle *excluded_symbols, - const mx_uint num_offline, - const char **offline_params, - const char *quantized_dtype); + const char **excluded_symbols, + const mx_uint num_offline, const char **offline_params, + const char *quantized_dtype, const bool calib_quantize); /*! * \brief Set calibration table to node attributes in the sym @@ -1571,6 +1570,15 @@ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, const float* high_quantiles, SymbolHandle* ret_sym_handle); +/*! + * \brief Run subgraph pass based on the backend provided + * \param sym_handle symbol to be converted + * \param backend backend names for subgraph pass + * \param ret_sym_handle returned symbol + */ +MXNET_DLL int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend, + SymbolHandle *ret_sym_handle); + //-------------------------------------------- // Part 4: Executor interface //-------------------------------------------- diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index afae5dcfcffe..e877d35dbb5b 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -667,6 +667,12 @@ class NDArray { } #if MXNET_USE_MKLDNN == 1 + /* + * Create NDArray from mkldnn memory. + * mkldnn_mem The mkldnn memory to be managed. + * static_data If true, mkldnn memory won't be freed on destruction. + */ + explicit NDArray(const mkldnn::memory *mkldnn_mem, bool static_data = true); /* * Test if the data is stored in one of special MKLDNN format. */ @@ -742,6 +748,11 @@ class NDArray { * It's used by FullyConnected right now. */ NDArray MKLDNNDataReshape(const TShape &shape) const; + + /*! + * \ Fix mkldnn memory descriptor mismatch from NDArray. + */ + void UpdateMKLDNNMemDesc(); #endif /*! diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index aa5d4e6de784..dd818457f827 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -300,6 +300,14 @@ using FQuantizedOp = std::function; */ using FNeedRequantize = std::function; +/*! + * \brief Register a function to determine if the input of a quantized operator + * needs to be quantized. This is usually used for the quantized operators + * which can handle fp32 inputs directly. + */ +using FAvoidQuantizeInput = std::function; + } // namespace mxnet #endif // MXNET_OP_ATTR_TYPES_H_ diff --git a/mkldnn.mk b/mkldnn.mk index 1be0704dcde1..d79bbe7d2a0e 100644 --- a/mkldnn.mk +++ b/mkldnn.mk @@ -47,7 +47,7 @@ $(MKLDNN_LIBFILE): mkldnn_clean: $(RM) -r 3rdparty/mkldnn/build - $(RM) -r 3rdparty/mkldnn/install/* + $(RM) -r $(MKLDNNROOT) ifeq ($(USE_MKLDNN), 1) mkldnn: mkldnn_build diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 8df923908fec..3b04016351ad 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -40,7 +40,7 @@ from ..module import Module -def _quantize_params(qsym, params): +def _quantize_params(qsym, params, th_dict): """Given a quantized symbol and a dict of params that have not been quantized, generate quantized params. Currently only supports quantizing the arg_params with names of `weight` or `bias`, not aux_params. If `qsym` contains symbols @@ -53,6 +53,7 @@ def _quantize_params(qsym, params): qsym : Symbol Quantized symbol from FP32 symbol. params : dict of str->NDArray + th_dict: dict of min/max pairs of layers' output """ inputs_name = qsym.list_arguments() quantized_params = {} @@ -69,11 +70,18 @@ def _quantize_params(qsym, params): quantized_params[name+'_max'] = vmax elif name in params: quantized_params[name] = params[name] + elif name.endswith(('_min')): + output = name[: - len('_min')] + "_output" + if output in th_dict: + quantized_params[name] = ndarray.array([th_dict[output][0]]) + elif name.endswith(('_max')): + output = name[: - len('_min')] + "_output" + if output in th_dict: + quantized_params[name] = ndarray.array([th_dict[output][1]]) return quantized_params - def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, - quantized_dtype='int8'): + quantized_dtype='int8', calib_quantize_op=False): """Given a symbol object representing a neural network of data type FP32, quantize it into a INT8 network. @@ -81,22 +89,24 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, ---------- sym : Symbol FP32 neural network symbol. - excluded_symbols : list of symbols - Nodes in the network that users do not want to replace with a symbol of INT8 data type. + excluded_sym_names : list of strings + A list of strings representing the names of the symbols that users want to excluding + from being quantized. offline_params : list of strs Names of the parameters that users want to quantize offline. It's always recommended to quantize parameters offline so that quantizing parameters during the inference can be avoided. quantized_dtype: str The quantized destination type for input data. + calib_quantize_op : bool + Whether perform offline calibration for quantize op. """ num_excluded_symbols = 0 - excluded_handles = [] if excluded_symbols is not None: assert isinstance(excluded_symbols, list) num_excluded_symbols = len(excluded_symbols) - for s in excluded_symbols: - excluded_handles.append(s.handle) + else: + excluded_symbols = [] num_offline = 0 offline = [] @@ -109,10 +119,11 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, check_call(_LIB.MXQuantizeSymbol(sym.handle, ctypes.byref(out), mx_uint(num_excluded_symbols), - c_array(SymbolHandle, excluded_handles), + c_str_array(excluded_symbols), mx_uint(num_offline), c_array(ctypes.c_char_p, offline), - c_str(quantized_dtype))) + c_str(quantized_dtype), + ctypes.c_bool(calib_quantize_op))) return Symbol(out) @@ -254,9 +265,6 @@ def _smooth_distribution(p, eps=0.0001): # pylint: disable=line-too-long def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): """Given a dataset, find the optimal threshold for quantizing it. - The reference distribution is `q`, and the candidate distribution is `p`. - `q` is a truncated version of the original distribution. - Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf """ if isinstance(arr, NDArray): @@ -307,10 +315,10 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): right_outlier_count = np.sum(hist[p_bin_idx_stop:]) p[-1] += right_outlier_count # is_nonzeros[k] indicates whether hist[k] is nonzero - is_nonzeros = (p != 0).astype(np.int32) + is_nonzeros = (sliced_nd_hist != 0).astype(np.int32) # calculate how many bins should be merged to generate quantized distribution q - num_merged_bins = sliced_nd_hist.size // num_quantized_bins + num_merged_bins = p.size // num_quantized_bins # merge hist into num_quantized_bins bins for j in range(num_quantized_bins): start = j * num_merged_bins @@ -318,17 +326,17 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): quantized_bins[j] = sliced_nd_hist[start:stop].sum() quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum() # expand quantized_bins into p.size bins - q = np.zeros(sliced_nd_hist.size, dtype=np.float32) + q = np.zeros(p.size, dtype=np.float32) for j in range(num_quantized_bins): start = j * num_merged_bins if j == num_quantized_bins - 1: - stop = len(is_nonzeros) + stop = -1 else: stop = start + num_merged_bins norm = is_nonzeros[start:stop].sum() if norm != 0: q[start:stop] = float(quantized_bins[j]) / float(norm) - q[p == 0] = 0 + q[sliced_nd_hist == 0] = 0 p = _smooth_distribution(p) # There is a chance that q is an invalid probability distribution. try: @@ -336,6 +344,7 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): except ValueError: divergence[i - num_half_quantized_bins] = float("inf") divergence[i - num_half_quantized_bins] = stats.entropy(p, q) + quantized_bins[:] = 0 min_divergence_idx = np.argmin(divergence) min_divergence = divergence[min_divergence_idx] @@ -363,7 +372,10 @@ def _get_optimal_thresholds(nd_dict, num_bins=8001, num_quantized_bins=255, logg _get_optimal_threshold(nd_dict[name], num_bins=num_bins, num_quantized_bins=num_quantized_bins) del nd_dict[name] # release the memory of ndarray - th_dict[name] = (-opt_th, opt_th) + if min_val < 0: + th_dict[name] = (-opt_th, opt_th) + else: + th_dict[name] = (0, opt_th) if logger is not None: logger.info('layer=%s, min_val=%f, max_val=%f, min_divergence=%f, optimal_threshold=%f' % (name, min_val, max_val, min_divergence, opt_th)) @@ -408,12 +420,11 @@ def _load_params(params, logger=logging): raise ValueError('Unsupported params provided. Must be either a path to the param file or' ' a pair of dictionaries representing arg_params and aux_params') - def quantize_model(sym, arg_params, aux_params, data_names=('data',), label_names=('softmax_label',), ctx=cpu(), excluded_sym_names=None, calib_mode='entropy', calib_data=None, num_calib_examples=None, calib_layer=None, - quantized_dtype='int8', logger=logging): + quantized_dtype='int8', calib_quantize_op=False, logger=logging): """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration. The backend quantized operators are only enabled for Linux systems. Please do not run inference using the quantized models on Windows for now. @@ -466,6 +477,8 @@ def quantize_model(sym, arg_params, aux_params, quantized_dtype : str The quantized destination type for input data. Currently support 'int8' and 'uint8', default value is 'int8'. + calib_quantize_op: bool + Whether calibrate quantize op with its input calibration data. The quantize op's input should be in calib_layer logger : Object A logging object for printing information during the process of quantization. @@ -481,24 +494,17 @@ def quantize_model(sym, arg_params, aux_params, raise ValueError('excluded_sym_names must be a list of strings representing' ' the names of the symbols that will not be quantized,' ' while received type %s' % str(type(excluded_sym_names))) - excluded_syms = [] - if excluded_sym_names is not None: - for sym_name in excluded_sym_names: - nodes = sym.get_internals() - idx = nodes.list_outputs().index(sym_name + '_output') - excluded_syms.append(nodes[idx]) - logger.info('Quantizing symbol') + logger.info('Quantizing symbol') if quantized_dtype not in ('int8', 'uint8'): raise ValueError('unknown quantized_dtype %s received,' ' expected `int8` or `uint8`' % quantized_dtype) - qsym = _quantize_symbol(sym, excluded_symbols=excluded_syms, + qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names, offline_params=list(arg_params.keys()), - quantized_dtype=quantized_dtype) - - logger.info('Quantizing parameters') - qarg_params = _quantize_params(qsym, arg_params) + quantized_dtype=quantized_dtype, + calib_quantize_op=calib_quantize_op) + th_dict = {} if calib_mode is not None and calib_mode != 'none': if not isinstance(ctx, Context): raise ValueError('currently only supports single ctx, while received %s' % str(ctx)) @@ -537,4 +543,7 @@ def quantize_model(sym, arg_params, aux_params, logger.info('Calibrating quantized symbol') qsym = _calibrate_quantized_sym(qsym, th_dict) + logger.info('Quantizing parameters') + qarg_params = _quantize_params(qsym, arg_params, th_dict) + return qsym, qarg_params, aux_params diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 554539b424ad..eaf22f3bec1a 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -2439,6 +2439,23 @@ def squeeze(self, *args, **kwargs): """ return op.squeeze(self, *args, **kwargs) + def get_backend_symbol(self, backend): + """Return symbol for target backend. + + Parameters + ---------- + backend : str + The backend names. + + Returns + ------- + out : Symbol + The created Symbol for target backend. + """ + out = SymbolHandle() + check_call(_LIB.MXGenBackendSubgraph(self.handle, c_str(backend), ctypes.byref(out))) + return Symbol(out) + def wait_to_read(self): raise NotImplementedForSymbol(self.wait_to_read, None) diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 35ecec7e11f6..d4625de80110 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -31,6 +31,7 @@ #include "./c_api_common.h" #include "../operator/operator_common.h" #include "../executor/exec_pass.h" +#include "../operator/subgraph/subgraph_property.h" namespace mxnet { namespace op { @@ -645,30 +646,29 @@ int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHand int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, - const mx_uint num_excluded_symbols, - const SymbolHandle *excluded_symbols, + const mx_uint num_excluded_op_names, + const char **excluded_op_names, const mx_uint num_offline, const char **offline_params, - const char *quantized_dtype) { + const char *quantized_dtype, + const bool calib_quantize) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); nnvm::Graph g = Symbol2Graph(*sym); - std::unordered_set excluded_nodes; - for (size_t i = 0; i < num_excluded_symbols; ++i) { - nnvm::Symbol* sym = static_cast(excluded_symbols[i]); - for (const auto& e : sym->outputs) { - excluded_nodes.emplace(e.node); - } + std::unordered_set excluded_node_names; + for (size_t i = 0; i < num_excluded_op_names; ++i) { + excluded_node_names.emplace(excluded_op_names[i]); } - g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_nodes)); std::unordered_set offline; for (size_t i = 0; i < num_offline; ++i) { offline.emplace(offline_params[i]); } std::string quantized_type(quantized_dtype); + g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_node_names)); g.attrs["offline_params"] = std::make_shared(std::move(offline)); g.attrs["quantized_dtype"] = std::make_shared(std::move(quantized_type)); + g.attrs["calib_quantize"] = std::make_shared(calib_quantize); g = ApplyPass(std::move(g), "QuantizeGraph"); s->outputs = g.outputs; *ret_sym_handle = s; @@ -696,3 +696,21 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, *ret_qsym_handle = s; API_END_HANDLE_ERROR(delete s); } + +int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend, + SymbolHandle *ret_sym_handle) { + nnvm::Symbol *s = new nnvm::Symbol(); + API_BEGIN(); + nnvm::Symbol *sym = static_cast(sym_handle); + *s = sym->Copy(); + nnvm::Graph g = Symbol2Graph(*s); + mxnet::op::SubgraphPropertyPtr property = + mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty( + backend); + g.attrs["subgraph_property"] = + std::make_shared(std::move(property)); + g = ApplyPass(std::move(g), "PartitionGraph"); + s->outputs = g.outputs; + *ret_sym_handle = s; + API_END_HANDLE_ERROR(delete s); +} diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 922917f79475..ed394234c525 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1532,14 +1532,14 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, // This is for bind flow. static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, const std::string& prop_name, - const std::vector &in_args, + std::vector *in_args, const std::vector &aux_states, const Context& default_ctx, const std::map& ctx_map) { const std::vector input_names = src.ListInputNames(Symbol::kAll); const std::vector arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs); const std::vector aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates); - CHECK_EQ(arg_names.size(), in_args.size()); + CHECK_EQ(arg_names.size(), in_args->size()); CHECK_EQ(aux_names.size(), aux_states.size()); nnvm::ShapeVector arg_shapes; // all input shapes arg_shapes.reserve(input_names.size()); @@ -1547,7 +1547,7 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, arg_dtypes.reserve(input_names.size()); StorageTypeVector arg_stypes; // all input stypes arg_stypes.reserve(input_names.size()); - std::vector in_arg_ctxes(in_args.size()); + std::vector in_arg_ctxes(in_args->size()); std::vector aux_state_ctxes(aux_states.size()); size_t i1 = 0, i2 = 0; @@ -1561,15 +1561,32 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, } else { CHECK(i1 < arg_names.size()); CHECK_EQ(arg_names[i1], input_names[i]); - arg_shapes.push_back(in_args[i1].shape()); - arg_dtypes.push_back(in_args[i1].dtype()); - arg_stypes.push_back(in_args[i1].storage_type()); - in_arg_ctxes[i1] = in_args[i1].ctx(); + arg_shapes.push_back(in_args->at(i1).shape()); + arg_dtypes.push_back(in_args->at(i1).dtype()); + arg_stypes.push_back(in_args->at(i1).storage_type()); + in_arg_ctxes[i1] = in_args->at(i1).ctx(); ++i1; } } - return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, - default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes); + + // setup in_args_map + std::unordered_map in_args_map; + for (size_t i = 0; i < in_args->size(); ++i) { + in_args_map[arg_names[i]] = in_args->at(i); + } + auto result = PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, default_ctx, + ctx_map, in_arg_ctxes, aux_state_ctxes); + // Reorder in_args into new_in_args according to partitioned symbol input sequence + std::vector new_in_args(in_args->size()); + // get new symbol in_arg names + std::vector new_arg_names = result.ListInputNames(nnvm::Symbol::kReadOnlyArgs); + CHECK_EQ(arg_names.size(), new_arg_names.size()); + in_args->clear(); + for (auto arg_name : new_arg_names) { + CHECK(in_args_map.count(arg_name)); + in_args->push_back(in_args_map[arg_name]); + } + return result; } } // namespace exec @@ -1613,12 +1630,13 @@ Executor *Executor::Bind(nnvm::Symbol symbol, const std::vector &aux_states, Executor* shared_exec) { auto exec = new exec::GraphExecutor(); + std::vector tmp_in_args = in_args; if (!exec->subgraph_property().empty()) { - symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), in_args, aux_states, + symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), &tmp_in_args, aux_states, default_ctx, group2ctx); } exec->Init(symbol, default_ctx, group2ctx, - in_args, arg_grad_store, grad_req_type, aux_states, + tmp_in_args, arg_grad_store, grad_req_type, aux_states, reinterpret_cast(shared_exec)); return exec; } diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 47e0c5bbe75a..b9689db4d625 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -168,6 +168,18 @@ nnvm::Symbol NDArray::get_autograd_symbol() const { #if MXNET_USE_MKLDNN == 1 +NDArray::NDArray(const mkldnn::memory *mkldnn_mem, bool static_data) + : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) { + auto mem_pd = mkldnn_mem->get_primitive_desc(); + auto mem_desc = mem_pd.desc(); + shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); + dtype_ = get_mxnet_type(mem_desc.data.data_type); + auto data = TBlob(mkldnn_mem->get_data_handle(), shape_, cpu::kDevMask, dtype_); + ptr_ = std::make_shared(data, 0); + ptr_->mkl_mem_ = std::make_shared(mem_pd, ptr_->shandle.dptr); + ptr_->static_data = static_data; +} + NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const { CHECK(!is_none()) << "NDArray is not initialized"; CHECK_GE(shape_.Size(), shape.Size()) @@ -477,24 +489,6 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) { mkl_mem_.reset(new MKLDNNMemory(pd, shandle.dptr)); } -/* - * Here we want to get MKLDNN memory whose primitive desc is exactly the same as - * the given one. operator== can't guarantee that. == can return true even if - * the formats are different. I need to double check its format. - */ -static inline mkldnn::memory *GetMKLDNNExact( - const mkldnn::memory *mem, mkldnn::memory::primitive_desc desc) { - mkldnn::memory::primitive_desc src_desc = mem->get_primitive_desc(); - if (desc == src_desc && desc.desc().data.format == src_desc.desc().data.format) { - return const_cast(mem); - } else { - std::shared_ptr ret(new mkldnn::memory( - desc, mem->get_data_handle())); - MKLDNNStream::Get()->RegisterMem(ret); - return ret.get(); - } -} - const mkldnn::memory *NDArray::GetMKLDNNData( const mkldnn::memory::primitive_desc &desc) const { if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) { @@ -722,6 +716,21 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc & MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); return ptr_->mkl_mem_->GetRaw(); } + +void NDArray::UpdateMKLDNNMemDesc() { + const mkldnn::memory *mem = GetMKLDNNData(); + auto mem_desc = mem->get_primitive_desc().desc(); + auto this_dtype = get_mkldnn_type(dtype()); + if (this_dtype != mem_desc.data.data_type) { + mkldnn::memory::desc data_md( + mkldnn::memory::dims(mem_desc.data.dims, + mem_desc.data.dims + mem_desc.data.ndims), + this_dtype, static_cast(mem_desc.data.format)); + mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine()); + ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr)); + MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); + } +} #endif void NDArray::SetTBlob() const { @@ -1592,8 +1601,12 @@ void NDArray::Save(dmlc::Stream *strm) const { save_data = nd_cpu.data(); } else { this->WaitToRead(); - save_data = this->data(); nd_cpu = *this; +#if MXNET_USE_MKLDNN == 1 + if (nd_cpu.IsMKLDNNData()) + nd_cpu = nd_cpu.Reorder2Default(); +#endif + save_data = nd_cpu.data(); } // save type flag diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index c4a4e52480b6..8a2f4a3e5011 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -132,6 +132,11 @@ static inline bool SupportMKLDNN(int dtype, const TShape &shape) { return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4); } +static inline bool SupportMKLDNNQuantize(int dtype) { + return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 || + dtype == mshadow::kUint8; +} + static inline bool SupportMKLDNN(const NDArray &input) { return SupportMKLDNN(input.dtype(), input.shape()) && SupportStorageMKLDNN(input.storage_type()); @@ -186,6 +191,23 @@ static inline mkldnn::memory::data_type get_mkldnn_type(int dtype) { } } +static inline int get_mxnet_type(mkldnn_data_type_t dtype) { + auto mkldnn_dtype = static_cast(dtype); + switch (mkldnn_dtype) { + case mkldnn::memory::data_type::f32: + return mshadow::kFloat32; + case mkldnn::memory::data_type::s32: + return mshadow::kInt32; + case mkldnn::memory::data_type::s8: + return mshadow::kInt8; + case mkldnn::memory::data_type::u8: + return mshadow::kUint8; + default: + LOG(FATAL) << "unknown MKLDNN type"; + return mshadow::kFloat32; + } +} + inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int ndim) { mkldnn::memory::dims dims(ndim); for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i]; @@ -327,6 +349,24 @@ enum OutDataOp { typedef std::pair mkldnn_output_t; void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem); +/* + * Here we want to get MKLDNN memory whose primitive desc is exactly the same as + * the given one. operator== can't guarantee that. == can return true even if + * the formats are different. I need to double check its format. + */ +static inline mkldnn::memory *GetMKLDNNExact( + const mkldnn::memory *mem, mkldnn::memory::primitive_desc desc) { + mkldnn::memory::primitive_desc src_desc = mem->get_primitive_desc(); + if (desc == src_desc && desc.desc().data.format == src_desc.desc().data.format) { + return const_cast(mem); + } else { + std::shared_ptr ret(new mkldnn::memory( + desc, mem->get_data_handle())); + MKLDNNStream::Get()->RegisterMem(ret); + return ret.get(); + } +} + /* * These two functions try to create MKLDNN memory in an NDArray based on `req'. * The difference is that the first function can create MKLDNN memory with diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 029f23bd8f5e..a60d6555c74d 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -332,6 +332,7 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) { } else if (desc.data.ndims == 5) { switch (desc.data.format) { case mkldnn_goihw: + case mkldnn_hwigo: case mkldnn_gOIhw8i8o: case mkldnn_gOIhw16i16o: case mkldnn_gOIhw4i16o4i: diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index 23f2fe694633..971c66ad9dd2 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -27,6 +27,7 @@ #if MXNET_USE_MKLDNN == 1 +#include #include #include "../convolution-inl.h" #include "./mkldnn_ops-inl.h" @@ -35,19 +36,68 @@ namespace mxnet { namespace op { -mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( - const ConvolutionParam& param, const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output); +struct MKLDNNConvParam : public dmlc::Parameter { + bool with_bn; + bool with_relu; + bool with_sum; + bool with_postsum_relu; + bool quantized; + bool weight_channelwise_scale; + + dmlc::optional min_calib_range; // min float value calculated from calibration dataset + dmlc::optional max_calib_range; // max float value calculated from calibration dataset + + DMLC_DECLARE_PARAMETER(MKLDNNConvParam) { + DMLC_DECLARE_FIELD(with_bn).set_default(false) + .describe("Add post batchnorm."); + DMLC_DECLARE_FIELD(with_relu).set_default(false) + .describe("Add post relu"); + DMLC_DECLARE_FIELD(with_sum).set_default(false) + .describe("Add post sum"); + DMLC_DECLARE_FIELD(with_postsum_relu).set_default(false) + .describe("Add post relu after sum"); + DMLC_DECLARE_FIELD(quantized).set_default(false) + .describe("enable quantization"); + DMLC_DECLARE_FIELD(weight_channelwise_scale).set_default(true) + .describe("Quantize weight with channel wise scales."); + DMLC_DECLARE_FIELD(min_calib_range) + .set_default(dmlc::optional()) + .describe("The minimum scalar value in the form of float32 obtained " + "through calibration. If present, it will be used to by " + "quantized convolution op to calculate primitive scale"); + DMLC_DECLARE_FIELD(max_calib_range) + .set_default(dmlc::optional()) + .describe("The maximum scalar value in the form of float32 obtained " + "through calibration. If present, it will be used to by " + "quantized convolution op to calculate primitive scale"); + } +}; + +struct MKLDNNConvFullParam { + ConvolutionParam conv_param; + MKLDNNConvParam mkldnn_param; + float sum_scale; + std::vector requantize_scales; +}; + +static inline bool IsOutputUInt8(const MKLDNNConvParam &mkldnn_param) { + return ((!mkldnn_param.with_sum) && mkldnn_param.with_relu) || + mkldnn_param.with_postsum_relu; +} + +mkldnn::convolution_forward::primitive_desc +GetConvFwdImpl(const MKLDNNConvFullParam ¶m, const bool is_train, + const NDArray &data, const NDArray &weights, const NDArray *bias, + const NDArray &output); class MKLDNNConvForward { public: mkldnn::convolution_forward::primitive_desc fwd_pd; - MKLDNNConvForward(const ConvolutionParam& param, const bool is_train, + MKLDNNConvForward(const MKLDNNConvFullParam ¶m, const bool is_train, const NDArray &data, const NDArray &weights, - const NDArray *bias, const NDArray &output): fwd_pd( - GetConvFwdImpl(param, is_train, data, weights, bias, output)) { - } + const NDArray *bias, const NDArray &output) + : fwd_pd(GetConvFwdImpl(param, is_train, data, weights, bias, output)) {} void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, const mkldnn::memory *bias, const mkldnn::memory &output); @@ -66,9 +116,17 @@ class MKLDNNConvForward { typedef ParamOpSign MKLDNNConvSignature; -MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs, - const bool is_train, const NDArray &data, const NDArray &weights, - const NDArray *bias, const NDArray &output); +MKLDNNConvForward &GetConvFwd(const ConvolutionParam ¶m, + const bool is_train, const NDArray &data, + const NDArray &weights, const NDArray *bias, + const NDArray &output); + +void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, + const OpContext &ctx, + MKLDNNConvForward *fwd, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); } // namespace op } // namespace mxnet diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index 2e19d3219abb..6a70ae40ac8f 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -34,55 +34,83 @@ namespace mxnet { namespace op { +DMLC_REGISTER_PARAMETER(MKLDNNConvParam); + bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) { if (params.kernel.ndim() != 2) return false; - return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4; + return SupportMKLDNNQuantize(input.dtype()) && input.shape().ndim() == 4; } mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( - const ConvolutionParam& param, const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output) { + const MKLDNNConvFullParam ¶m, const bool is_train, + const NDArray &data, const NDArray &weights, const NDArray *bias, + const NDArray &output) { auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.num_group); + auto weight_md = GetWeightDesc(weights, param.conv_param.num_group); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); - CHECK_GE(param.stride.ndim(), 2U); - CHECK_GE(param.pad.ndim(), 2U); - CHECK_GE(param.dilate.ndim(), 2U); + CHECK_GE(param.conv_param.stride.ndim(), 2U); + CHECK_GE(param.conv_param.pad.ndim(), 2U); + CHECK_GE(param.conv_param.dilate.ndim(), 2U); mkldnn::memory::dims strides{0, 0}; - strides[0] = param.stride[0]; - strides[1] = param.stride[1]; + strides[0] = param.conv_param.stride[0]; + strides[1] = param.conv_param.stride[1]; mkldnn::memory::dims padding{0, 0}; - padding[0] = param.pad[0]; - padding[1] = param.pad[1]; - if (param.dilate.ndim() == 0 && bias == nullptr) { + padding[0] = param.conv_param.pad[0]; + padding[1] = param.conv_param.pad[1]; + mkldnn::primitive_attr attr; + mkldnn::post_ops ops; + if (param.mkldnn_param.with_relu) { + float scale = 1.0f; // for fp32, scale is 1. + float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. + float beta = 1.0f; // ignored for mkldnn_eltwise_relu. + ops.append_eltwise(scale, eltwise_relu, alpha, beta); + } + if (param.mkldnn_param.with_sum) { + ops.append_sum(param.sum_scale); + } + if (param.mkldnn_param.with_postsum_relu) { + float scale = 1.0f; // for fp32, scale is 1. + float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. + float beta = 1.0f; // ignored for mkldnn_eltwise_relu. + ops.append_eltwise(scale, eltwise_relu, alpha, beta); + } + attr.set_post_ops(ops); + + if (param.mkldnn_param.quantized && param.requantize_scales.size()) { + int mask = param.mkldnn_param.weight_channelwise_scale ? 2 : 0; + attr.set_output_scales(mask, param.requantize_scales); + attr.set_int_output_round_mode(round_nearest); + } + + if (param.conv_param.dilate.ndim() == 0 && bias == nullptr) { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); - return mkldnn::convolution_forward::primitive_desc(desc, engine); - } else if (param.dilate.ndim() == 0) { + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); + } else if (param.conv_param.dilate.ndim() == 0) { auto bias_md = GetMemDesc(*bias); mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, bias_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); - return mkldnn::convolution_forward::primitive_desc(desc, engine); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); } else { mkldnn::memory::dims dilates{0, 0}; - dilates[0] = param.dilate[0] - 1; - dilates[1] = param.dilate[1] - 1; + dilates[0] = param.conv_param.dilate[0] - 1; + dilates[1] = param.conv_param.dilate[1] - 1; if (bias == nullptr) { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, out_md, strides, dilates, padding, padding, mkldnn::padding_kind::zero); - return mkldnn::convolution_forward::primitive_desc(desc, engine); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); } else { auto bias_md = GetMemDesc(*bias); mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, bias_md, out_md, strides, dilates, padding, padding, mkldnn::padding_kind::zero); - return mkldnn::convolution_forward::primitive_desc(desc, engine); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); } } } @@ -207,15 +235,15 @@ void MKLDNNConvForward::SetNewMem(const mkldnn::memory &data, } } -MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs, const bool is_train, - const NDArray &data, const NDArray &weights, - const NDArray *bias, const NDArray &output) { +MKLDNNConvForward &GetConvFwd(const ConvolutionParam ¶m, + const bool is_train, const NDArray &data, + const NDArray &weights, const NDArray *bias, + const NDArray &output) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map fwds; #else static MX_THREAD_LOCAL std::unordered_map fwds; #endif - const ConvolutionParam& param = nnvm::get(attrs.parsed); MKLDNNConvSignature key(param); key.AddSign(is_train); // Here we can sign the conv op with NDArray because conv primitive will @@ -229,7 +257,10 @@ MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs, const bool is_train, auto it = fwds.find(key); if (it == fwds.end()) { - MKLDNNConvForward fwd(param, is_train, data, weights, bias, output); + MKLDNNConvFullParam full_param; + full_param.conv_param = param; + full_param.mkldnn_param.Init(std::unordered_map()); + MKLDNNConvForward fwd(full_param, is_train, data, weights, bias, output); auto ins_ret = fwds.insert( std::pair(key, fwd)); CHECK(ins_ret.second); @@ -238,17 +269,17 @@ MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs, const bool is_train, return it->second; } -void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data) { +void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, + const OpContext &ctx, + MKLDNNConvForward *fwd, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); - const ConvolutionParam& param = nnvm::get(attrs.parsed); NDArray weight = in_data[conv::kWeight]; - MKLDNNConvForward &fwd = GetConvFwd(attrs, ctx.is_train, in_data[conv::kData], weight, - param.no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]); - - auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc()); + bool no_bias = param.conv_param.no_bias && !param.mkldnn_param.with_bn; + auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder( + fwd->fwd_pd.src_primitive_desc()); const mkldnn::memory *weight_mem; if (ctx.is_train) { // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it @@ -257,32 +288,58 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx // This asks the engine to change the layout of the weight array after // it's used. weight.Reorder2DefaultAsync(); - weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group); + weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), + param.conv_param.num_group); } else { // For inference, we want to reorder the weight array so we don't need to // reorder data every time. if (weight.IsDefaultData()) { - weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group); + weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), + param.conv_param.num_group); // We also need to modify the layout on the original weight array. The // data conversion happens after the weight array is used. - weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc()); + weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc()); } else { weight_mem = weight.GetMKLDNNData(); - CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc()); + CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc()); } } - auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(), - req[conv::kOut]); + mkldnn_output_t out_mem; + if (param.mkldnn_param.with_sum) { + out_mem = mkldnn_output_t( + OutDataOp::Noop, + const_cast(out_data[conv::kOut].GetMKLDNNData())); + } else { + out_mem = CreateMKLDNNMem(out_data[conv::kOut], + fwd->fwd_pd.dst_primitive_desc(), req[conv::kOut]); + } + const mkldnn::memory *bias_mem = nullptr; - if (!param.no_bias) - bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc()); - fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + if (!no_bias) { + bias_mem = in_data[conv::kBias].GetMKLDNNData(); + } + fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); + MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd()); CommitOutput(out_data[conv::kOut], out_mem); MKLDNNStream::Get()->Submit(); } +void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + MKLDNNConvFullParam param; + param.conv_param = nnvm::get(attrs.parsed); + param.mkldnn_param.Init(std::unordered_map()); + auto &fwd = GetConvFwd( + param.conv_param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], + param.conv_param.no_bias ? nullptr : &in_data[conv::kBias], + out_data[conv::kOut]); + MKLDNNConvolutionForwardFullFeature(param, ctx, &fwd, in_data, req, out_data); +} + class MKLDNNConvBackward { std::shared_ptr bwd_data; std::shared_ptr bwd_weight; @@ -440,10 +497,14 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct const std::vector& outputs) { TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); const std::vector &in_grad = outputs; - const ConvolutionParam& param = nnvm::get(attrs.parsed); - mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl(param, ctx.is_train, - inputs[conv::kData + 1], inputs[conv::kWeight + 1], - param.no_bias ? nullptr : &inputs[conv::kBias + 1], inputs[conv::kOut]); + MKLDNNConvFullParam full_param; + full_param.conv_param = nnvm::get(attrs.parsed); + full_param.mkldnn_param.Init(std::unordered_map()); + mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl( + full_param, ctx.is_train, inputs[conv::kData + 1], inputs[conv::kWeight + 1], + full_param.conv_param.no_bias ? nullptr : &inputs[conv::kBias + 1], + inputs[conv::kOut]); + const ConvolutionParam ¶m = full_param.conv_param; CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace"; MKLDNNConvBackward &convBwd = GetConvBwd(attrs, inputs[conv::kData + 1], diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h index f7709319d6a2..7a00f621d452 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h @@ -75,6 +75,11 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, auto i_mpd = i_mem->get_primitive_desc(); auto i_desc = i_mpd.desc(); mkldnn::memory::format i_fmt = static_cast(i_desc.data.format); + if (i_fmt == mkldnn::memory::format::nchw || + i_fmt == mkldnn::memory::format::nChw8c || + i_fmt == mkldnn_nChw16c) { + i_fmt = mkldnn::memory::format::nhwc; + } size_t i_ndim = in_buffer.shape().ndim(); mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); for (size_t i = 0; i < i_ndim; i++) { diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc index fa6a32a47392..b8c47c3af11b 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc @@ -41,12 +41,12 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_data[0].dtype(), mshadow::kUint8) << "mkldnn_quantized_conv op only supports uint8 as input type"; TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); - const ConvolutionParam& param = nnvm::get(attrs.parsed); NDArray weight = in_data[conv::kWeight]; - MKLDNNConvForward &fwd = GetConvFwd(attrs, ctx.is_train, - in_data[conv::kData], weight, - param.no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]); - + ConvolutionParam param = nnvm::get(attrs.parsed); + auto &fwd = GetConvFwd( + param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], + param.no_bias ? nullptr : &in_data[conv::kBias], + out_data[conv::kOut]); auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc()); const mkldnn::memory *weight_mem; // For inference, we want to reorder the weight array so we don't need to diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 10834868d2b5..2fa790dc88ef 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -89,17 +89,42 @@ std::vector OfflineParams(std::vector&& outputs, return outputs; } -inline bool NeedQuantize(NodePtr node, const std::unordered_set excluded_nodes) { +inline bool NeedQuantize(NodePtr node, const std::unordered_set& excluded_nodes) { static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); - return quantized_op_map.count(node->op()) && !excluded_nodes.count(node); + static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); + const auto& op = node->op(); + if (op && quantized_op_map.count(op)) { + bool need = true; + if (excluded_nodes.count(node->attrs.name)) { + need = false; + } else if (!node->attrs.subgraphs.empty()) { + ExecType exec_type = fexec_type.count(op) ? fexec_type[op](node->attrs) : ExecType::kSync; + if (exec_type != ExecType::kSubgraphExec) { + // This is a fused subgraph node, try to match inner node. + CHECK_EQ(node->attrs.subgraphs.size(), 1); + auto subgraph_sym = node->attrs.subgraphs[0]; + DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr& n) { + if (n->is_variable()) return; + if (excluded_nodes.count(n->attrs.name)) { + need = false; + } + }); + } + } + return need; + } + return false; } Graph QuantizeGraph(Graph &&src) { static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); static auto& need_requantize_map = Op::GetAttr("FNeedRequantize"); + static auto& avoid_quantize_input_map = + Op::GetAttr("FAvoidQuantizeInput"); auto offline_params = src.GetAttr>("offline_params"); - auto excluded_nodes = src.GetAttr>("excluded_nodes"); + auto excluded_nodes = src.GetAttr>("excluded_nodes"); auto quantized_dtype = src.GetAttr("quantized_dtype"); + auto calib_quantize = src.GetAttr("calib_quantize"); // mirror_map stores the mapping from the currently visited graph to the newly created quantized // graph. Key is the currently visited graph's node pointer, and value is a copied node of the key @@ -116,7 +141,8 @@ Graph QuantizeGraph(Graph &&src) { new_node = fquantized_op(node->attrs); // add data into quantized op input - for (const auto& e : node->inputs) { + for (size_t i = 0; i < node->inputs.size(); ++i) { + const auto& e = node->inputs[i]; NodePtr mirror_node = mirror_map.at(e.node.get()); NodeEntry mirror_entry = NodeEntry{ mirror_node, e.index, e.version}; @@ -125,23 +151,34 @@ Graph QuantizeGraph(Graph &&src) { // taking mirror_entry as input to generate a quantized NDArray. Save the mapping between // e's source node and the newly created quantize op so that the quantize op can be // reused next time when the same entry is visited again. - if (!NeedQuantize(e.node, excluded_nodes) && - (mirror_node->op() == nullptr || - mirror_node->op()->name != "_contrib_quantize")) { + if (avoid_quantize_input_map.count(node->op()) && + avoid_quantize_input_map[node->op()](node->attrs, i)) { + new_node->inputs.emplace_back(mirror_entry); + } else if (!NeedQuantize(e.node, excluded_nodes) && + (mirror_node->op() == nullptr || + mirror_node->op()->name != "_contrib_quantize")) { NodePtr quantize_node = InsertNode("_contrib_quantize", e.node->attrs.name + "_quantize", new_node, mirror_entry); quantize_node->attrs.dict["out_type"] = quantized_dtype; quantize_node->op()->attr_parser(&(quantize_node->attrs)); + if (calib_quantize) { + NodePtr min_var = CreateNode("nullptr", e.node->attrs.name + "_min"); + quantize_node->inputs.emplace_back(NodeEntry{min_var, 0, 0}); + NodePtr max_var = CreateNode("nullptr", e.node->attrs.name + "_max"); + quantize_node->inputs.emplace_back(NodeEntry{max_var, 0, 0}); + } else { + NodePtr min_node = InsertNode("min", + e.node->attrs.name + "_min", quantize_node, mirror_entry); + min_node->op()->attr_parser(&(min_node->attrs)); - NodePtr min_node = InsertNode("min", - e.node->attrs.name + "_min", quantize_node, mirror_entry); - min_node->op()->attr_parser(&(min_node->attrs)); - - NodePtr max_node = InsertNode("max", - e.node->attrs.name + "_max", quantize_node, mirror_entry); - max_node->op()->attr_parser(&(max_node->attrs)); - + NodePtr max_node = InsertNode("max", + e.node->attrs.name + "_max", quantize_node, mirror_entry); + max_node->op()->attr_parser(&(max_node->attrs)); + } mirror_map[e.node.get()] = std::move(quantize_node); + } else if (mirror_node->op() != nullptr + && mirror_node->op()->name == "_contrib_dequantize") { + new_node->inputs.emplace_back(NodeEntry{mirror_node->inputs[0].node, e.index, e.version}); } else { // If the entry e's node needs quantization, or mirror_entry is from a quantize op, // simply add mirror_entry to the input of the new_node. @@ -152,24 +189,35 @@ Graph QuantizeGraph(Graph &&src) { // add min and max into quantized op input assume order of quantized op inputs is: // data1, data2, ..., min1, max1, min2, max2, ... - for (const auto& e : node->inputs) { + for (size_t i = 0; i < node->inputs.size(); ++i) { + const auto& e = node->inputs[i]; NodePtr mirror_node = mirror_map.at(e.node.get()); + if (mirror_node->op() != nullptr + && mirror_node->op()->name == "_contrib_dequantize") { + mirror_node = mirror_node->inputs[0].node; + } NodeEntry mirror_entry = NodeEntry{ mirror_node, e.index, e.version}; // for quantize node uint32_t min_index = 1; uint32_t max_index = 2; + if (avoid_quantize_input_map.count(node->op()) && + avoid_quantize_input_map[node->op()](node->attrs, i)) { + // skip non-quantized input + continue; + } if (quantized_op_map.count(e.node->op())) { // here we calculate the output number (exclude min/max, in order to // calculate min/max index from mirror node) based on assumption that // there is only 1min and 1max output from mirror node (which is // currently true) - size_t num_outputs = mirror_node->num_outputs() - 2; + size_t num_outputs = mirror_node->num_outputs() - 2; min_index = num_outputs + 2 * e.index; max_index = num_outputs + 2 * e.index + 1; } else { - CHECK(mirror_node->op()->name == "_contrib_quantize") - << "The input is not quantize or quantized_op"; + CHECK(mirror_node->op() != nullptr && + mirror_node->op()->name == "_contrib_quantize") + << "The input is not quantize or quantized_op"; } new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); @@ -178,8 +226,8 @@ Graph QuantizeGraph(Graph &&src) { // If the new_node op registered attr FNeedRequantize, insert requantize node after it. // Here it's assumed that the quantized_op node only produces three outputs: // out_data, min_range, and max_range. - if (need_requantize_map.count(new_node->op()) > 0 - && need_requantize_map[new_node->op()](new_node->attrs)) { + if (need_requantize_map.count(new_node->op()) > 0 && + need_requantize_map[new_node->op()](new_node->attrs)) { NodePtr requantize_node = Node::Create(); requantize_node->attrs.op = Op::Get("_contrib_requantize"); requantize_node->attrs.name = "requantize_" + node->attrs.name; @@ -187,7 +235,8 @@ Graph QuantizeGraph(Graph &&src) { requantize_node->op()->attr_parser(&(requantize_node->attrs)); } for (size_t i = 0; i < 3; ++i) { - requantize_node->inputs.emplace_back(NodeEntry{new_node, static_cast(i), 0}); + requantize_node->inputs.emplace_back( + NodeEntry{new_node, static_cast(i), 0}); } new_node = requantize_node; } @@ -199,33 +248,45 @@ Graph QuantizeGraph(Graph &&src) { // the new_node. *new_node = *node; new_node->inputs.clear(); - for (const auto& e : node->inputs) { - NodePtr mirror_node = mirror_map.at(e.node.get()); - NodeEntry mirror_entry = NodeEntry{ - mirror_node, e.index, e.version}; - // if input node is quantized operator, add dequantize node - if (NeedQuantize(e.node, excluded_nodes)) { - // here we calculate the output number (exclude min/max, in order to - // calculate min/max index from mirror node) based on assumption that - // there is only 1min and 1max output from mirror node (which is - // currently true) - size_t num_outputs = mirror_node->num_outputs() - 2; - uint32_t min_index = num_outputs + 2 * e.index; - uint32_t max_index = num_outputs + 2 * e.index + 1; - NodePtr dequantize_node = CreateNode("_contrib_dequantize", - e.node->attrs.name + "_dequantize"); - dequantize_node->inputs.emplace_back(mirror_entry); - dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); - dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); - dequantize_node->op()->attr_parser(&(dequantize_node->attrs)); + if (node->is_variable() && node->attrs.name == "data") { + // Insert identity for data to collect calib for it. + NodePtr identity_node = + CreateNode("identity", new_node->attrs.name + "_id"); + identity_node->inputs.emplace_back(NodeEntry{new_node, 0, 0}); + new_node = identity_node; + } else { + for (const auto& e : node->inputs) { + NodePtr mirror_node = mirror_map.at(e.node.get()); + NodeEntry mirror_entry = NodeEntry{ + mirror_node, e.index, e.version}; + // if input node is quantized operator, add dequantize node + if (NeedQuantize(e.node, excluded_nodes) && + (mirror_node->op() == nullptr || + mirror_node->op()->name != "_contrib_dequantize")) { + // here we calculate the output number (exclude min/max, in order to + // calculate min/max index from mirror node) based on assumption that + // there is only 1min and 1max output from mirror node (which is + // currently true) + size_t num_outputs = mirror_node->num_outputs() - 2; + uint32_t min_index = num_outputs + 2 * e.index; + uint32_t max_index = num_outputs + 2 * e.index + 1; + NodePtr dequantize_node = CreateNode("_contrib_dequantize", + e.node->attrs.name + "_dequantize"); + dequantize_node->inputs.emplace_back(mirror_entry); + dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); + dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); + dequantize_node->op()->attr_parser(&(dequantize_node->attrs)); - new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0}); - mirror_map[e.node.get()] = std::move(dequantize_node); - } else if (mirror_node->op() != nullptr - && mirror_node->op()->name == "_contrib_quantize") { - new_node->inputs.emplace_back(NodeEntry{mirror_node->inputs[0].node, e.index, e.version}); - } else { - new_node->inputs.emplace_back(NodeEntry{mirror_node, e.index, e.version}); + new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0}); + mirror_map[e.node.get()] = std::move(dequantize_node); + } else if (mirror_node->op() != nullptr + && mirror_node->op()->name == "_contrib_quantize") { + new_node->inputs.emplace_back( + NodeEntry{mirror_node->inputs[0].node, e.index, e.version}); + } else { + new_node->inputs.emplace_back( + NodeEntry{mirror_node, e.index, e.version}); + } } } } diff --git a/src/operator/subgraph/default_subgraph_property.cc b/src/operator/subgraph/default_subgraph_property.cc index 0152344f4d43..5a2c52e61729 100644 --- a/src/operator/subgraph/default_subgraph_property.cc +++ b/src/operator/subgraph/default_subgraph_property.cc @@ -17,8 +17,6 @@ * under the License. */ -#include -#include #include "./common.h" #include "./subgraph_property.h" #include "../../imperative/cached_op.h" diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h new file mode 100644 index 000000000000..8675446f5a14 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_INL_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_INL_H_ +#if MXNET_USE_MKLDNN == 1 + +#include +#include +#include +#include "../../nn/convolution-inl.h" +#include "../../nn/batch_norm-inl.h" +#include "../../nn/mkldnn/mkldnn_convolution-inl.h" + +namespace mxnet { +namespace op { + +struct MKLDNNConvFusionParam { + MKLDNNConvFullParam full_conv_param; + std::shared_ptr bn_param; +}; + +static const size_t uint8_range = 255; +static const size_t int8_range = 127; + +enum MKLDNNConvOpOutputs { kOut, kMin, kMax }; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_INL_H_ diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc new file mode 100644 index 000000000000..a1083d09b7b5 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -0,0 +1,690 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +#if MXNET_USE_MKLDNN == 1 + +#include +#include +#include +#include "../common.h" +#include "../../nn/mkldnn/mkldnn_base-inl.h" +#include "../../nn/mkldnn/mkldnn_ops-inl.h" +#include "../../quantization/quantization_utils.h" +#include "mkldnn_conv-inl.h" + +namespace mxnet { +namespace op { + +template +static void UpdateConvWeightBias(NDArray *weight, NDArray *bias, bool no_bias, + const NDArray &gamma, const NDArray &beta, + const NDArray &mean, const NDArray &variance, + const BatchNormParam *param) { + // TODO(Zhennan): Handle the case weight is not in dims 4. + NDArray update_weight = NDArray(weight->storage_type(), weight->shape(), + weight->ctx(), true, weight->dtype()); + NDArray update_bias = NDArray(beta.storage_type(), beta.shape(), beta.ctx(), + true, beta.dtype()); + const DType *weight_ptr = weight->data().dptr(); + const DType *bias_ptr = no_bias ? nullptr : bias->data().dptr(); + const DType *gamma_ptr = gamma.Reorder2Default().data().dptr(); + const DType *beta_ptr = beta.Reorder2Default().data().dptr(); + const DType *mean_ptr = mean.Reorder2Default().data().dptr(); + const DType *var_ptr = variance.Reorder2Default().data().dptr(); + DType *update_weight_ptr = update_weight.data().dptr(); + DType *update_bias_ptr = update_bias.data().dptr(); + size_t channel = gamma.shape()[0]; + size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3]; +#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + for (int c = 0; c < static_cast(channel); ++c) { + const DType *p1 = weight_ptr + c * offset; + DType *p2 = update_weight_ptr + c * offset; + DType alpha = (param->fix_gamma ? static_cast(1.0f) : gamma_ptr[c]) / + sqrt(var_ptr[c] + param->eps); + + if (bias_ptr) + update_bias_ptr[c] = beta_ptr[c] + alpha * (bias_ptr[c] - mean_ptr[c]); + else + update_bias_ptr[c] = beta_ptr[c] - alpha * mean_ptr[c]; + + for (size_t k = 0; k < offset; ++k) { + p2[k] = p1[k] * alpha; + } + } + *weight = update_weight; + *bias = update_bias; +} + +static inline size_t GetInSumIndex(const MKLDNNConvFusionParam ¶m) { + return 2 + (param.full_conv_param.conv_param.no_bias ? 0 : 1) + + (param.full_conv_param.mkldnn_param.with_bn ? 4 : 0); +} + +template +static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, + bool has_bias, float data_scale, + bool weight_channelwise_scale, + std::vector *weight_scales) { + using red::limits::MaxValue; + using red::limits::MinValue; + const DType *weight_ptr = weight->data().dptr(); + NDArray quantized_weight = NDArray(weight->storage_type(), weight->shape(), + weight->ctx(), true, mshadow::kInt8); + int8_t *quan_weight_ptr = quantized_weight.data().dptr(); + size_t channel = weight->shape()[0]; + + // TODO(Zhennan): Handle the case weight is not in dims 4. + size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3]; + std::vector weight_c_min(channel, MaxValue()); + std::vector weight_c_max(channel, MinValue()); +#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + for (int c = 0; c < static_cast(channel); ++c) { + const DType *p1 = weight_ptr + c * offset; + for (size_t k = 0; k < offset; ++k) { + if (weight_c_min[c] > p1[k]) + weight_c_min[c] = p1[k]; + if (weight_c_max[c] < p1[k]) + weight_c_max[c] = p1[k]; + } + } + + if (weight_channelwise_scale) { + weight_scales->resize(channel); +#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + for (int c = 0; c < static_cast(channel); ++c) { + DType weight_range = MaxAbs(weight_c_min[c], weight_c_max[c]); + weight_scales->at(c) = int8_range / weight_range; + const DType *fp_ptr = weight_ptr + c * offset; + int8_t *quan_ptr = quan_weight_ptr + c * offset; + for (size_t k = 0; k < offset; ++k) { + quan_ptr[k] = std::round(weight_scales->at(c) * fp_ptr[k]); + } + } + } else { + DType total_min = weight_c_min[0]; + DType total_max = weight_c_max[0]; + for (size_t c = 0; c < channel; ++c) { + if (total_min > weight_c_min[c]) total_min = weight_c_min[c]; + if (total_max < weight_c_max[c]) total_max = weight_c_max[c]; + } + weight_scales->resize(1); + DType weight_range = MaxAbs(total_min, total_max); + weight_scales->at(0) = int8_range / weight_range; +#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + for (int c = 0; c < static_cast(channel); ++c) { + const DType *fp_ptr = weight_ptr + c * offset; + int8_t *quan_ptr = quan_weight_ptr + c * offset; + for (size_t k = 0; k < offset; ++k) { + quan_ptr[k] = std::round(weight_scales->at(0) * fp_ptr[k]); + } + } + } + + *weight = quantized_weight; + if (has_bias) { + const DType *bias_ptr = bias->data().dptr(); + NDArray quantized_bias = NDArray(bias->storage_type(), bias->shape(), + bias->ctx(), true, mshadow::kInt32); + int32_t *quan_bias_ptr = quantized_bias.data().dptr(); + for (size_t c = 0; c < channel; ++c) { + auto weight_scale = + weight_channelwise_scale ? weight_scales->at(c) : weight_scales->at(0); + float bias_scale = weight_scale * data_scale; + quan_bias_ptr[c] = std::round(bias_scale * bias_ptr[c]); + } + *bias = quantized_bias; + } +} + +static void ConvFusionFallBackCompute() { + LOG(FATAL) << "Don't know how to do ConvFusionFallBackCompute!"; +} + +static void ConvolutionFusionComputeExCPU(const MKLDNNConvFullParam &full_param, + const OpContext &ctx, + MKLDNNConvForward *fwd, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + if (SupportMKLDNNConv(full_param.conv_param, inputs[0])) { + MKLDNNConvolutionForwardFullFeature(full_param, ctx, fwd, inputs, req, outputs); + return; + } + ConvFusionFallBackCompute(); +} + +class SgMKLDNNConvOperator { + public: + explicit SgMKLDNNConvOperator(const nnvm::NodeAttrs &attrs) + : initalized_(false), + subgraph_sym_(*attrs.subgraphs[0]), + param_(nnvm::get(attrs.parsed)), + inplace_(false) {} + + void Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + void Backward(const OpContext &ctx, const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + LOG(FATAL) << "Not implemented: subgraph mkldnn Conv only supports " + "inference computation."; + } + + private: + bool initalized_; + nnvm::Symbol subgraph_sym_; + MKLDNNConvFusionParam param_; + std::shared_ptr fwd_; + NDArray cached_weight_; + NDArray cached_bias_; + float cached_data_min_; + float cached_data_max_; + float cached_sum_min_; + float cached_sum_max_; + size_t weight_ver_; + size_t bias_ver_; + std::vector weight_scales_; + bool inplace_; +}; + +void SgMKLDNNConvOperator::Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + auto &full_conv_param = param_.full_conv_param; + auto &mkldnn_param = param_.full_conv_param.mkldnn_param; + auto &conv_param = param_.full_conv_param.conv_param; + auto bn_param = param_.bn_param.get(); + size_t input_size = + 2 + (conv_param.no_bias ? 0 : 1) + (mkldnn_param.with_bn ? 4 : 0) + + (mkldnn_param.with_sum ? 1 : 0) + + (mkldnn_param.quantized ? 2 + (full_conv_param.mkldnn_param.with_sum ? 2 : 0) : 0); + CHECK_EQ(inputs.size(), input_size); + size_t idx = 0; + + auto in_data = idx++; + auto in_weight = idx++; + auto in_bias = conv_param.no_bias ? 0 : (idx++); + auto in_gamma = mkldnn_param.with_bn ? (idx++) : 0; + auto in_beta = mkldnn_param.with_bn ? (idx++) : 0; + auto in_mean = mkldnn_param.with_bn ? (idx++) : 0; + auto in_var = mkldnn_param.with_bn ? (idx++) : 0; + auto in_sum = mkldnn_param.with_sum ? (idx++) : 0; + float data_min = + mkldnn_param.quantized ? inputs[idx++].data().dptr()[0] : 0.0; + float data_max = + mkldnn_param.quantized ? inputs[idx++].data().dptr()[0] : 0.0; + float sum_min = (mkldnn_param.with_sum && mkldnn_param.quantized) + ? inputs[idx++].data().dptr()[0] + : 0.0; + float sum_max = (mkldnn_param.with_sum && mkldnn_param.quantized) + ? inputs[idx++].data().dptr()[0] + : 0.0; + float *out_min_ptr = + mkldnn_param.quantized ? outputs[kMin].data().dptr() : nullptr; + float *out_max_ptr = + mkldnn_param.quantized ? outputs[kMax].data().dptr() : nullptr; + CHECK_EQ(input_size, idx); + bool has_bias = mkldnn_param.with_bn || !conv_param.no_bias; + NDArray data = inputs[in_data]; + NDArray output = mkldnn_param.with_sum ? inputs[in_sum] : outputs[kOut]; + + // Copy inputs[in_sum] into outputs[kOut] in case inplace optimization failed. + if (mkldnn_param.with_sum) { + if (!initalized_) { + auto in_mkl_mem = inputs[in_sum].GetMKLDNNData(); + auto out_mkl_mem = outputs[kOut].GetMKLDNNData(); + // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace option, + // which make check (req[kOut] == kWriteInplace) useless. + if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) { + inplace_ = true; + } + } + if (!inplace_) { + auto in_mkl_mem = inputs[in_sum].GetMKLDNNData(); + const_cast(outputs[kOut]).CopyFrom(*in_mkl_mem); + output = NDArray(outputs[kOut].GetMKLDNNData()); + } + } + + // Check input change + // TODO(zhennan): Only update cached_* changed. + if (initalized_) { + if (mkldnn_param.with_bn) { + if (weight_ver_ != inputs[in_weight].version() || + ((!conv_param.no_bias) && bias_ver_ != inputs[in_bias].version())) { + initalized_ = false; + } + } + if (initalized_ && mkldnn_param.quantized) { + if (cached_data_min_ != data_min || cached_data_max_ != data_max || + cached_sum_min_ != sum_min || cached_sum_max_ != sum_max || + weight_ver_ != inputs[in_weight].version() || + ((!conv_param.no_bias) && bias_ver_ != inputs[in_bias].version())) { + initalized_ = false; + } + } + } + bool post_requantize = false; + if (mkldnn_param.quantized) { + if (mkldnn_param.min_calib_range.has_value() && + mkldnn_param.max_calib_range.has_value()) { + post_requantize = true; + mkldnn_param.weight_channelwise_scale = true; + *out_min_ptr = mkldnn_param.min_calib_range.value(); + *out_max_ptr = mkldnn_param.max_calib_range.value(); + } else { + mkldnn_param.weight_channelwise_scale = false; + } + } + + if (!initalized_) { + cached_data_min_ = data_min; + cached_data_max_ = data_max; + cached_sum_min_ = sum_min; + cached_sum_max_ = sum_max; + full_conv_param.sum_scale = 1.0; + cached_weight_ = inputs[in_weight].Reorder2Default(); + weight_ver_ = inputs[in_weight].version(); + if (!conv_param.no_bias) { + cached_bias_ = inputs[in_bias].Reorder2Default(); + bias_ver_ = inputs[in_bias].version(); + } else { + cached_bias_ = NDArray(); + } + + // Update weight and bias after bn fusion. + if (mkldnn_param.with_bn) { + CHECK_EQ(inputs[in_weight].dtype(), inputs[in_gamma].dtype()); + CHECK_EQ(inputs[in_weight].dtype(), inputs[in_beta].dtype()); + CHECK_EQ(inputs[in_weight].dtype(), inputs[in_var].dtype()); + MSHADOW_REAL_TYPE_SWITCH(inputs[in_weight].dtype(), DType, { + UpdateConvWeightBias(&cached_weight_, &cached_bias_, + conv_param.no_bias, inputs[in_gamma], + inputs[in_beta], inputs[in_mean], + inputs[in_var], bn_param); + }); + } + // Quantize weight and bias. + if (mkldnn_param.quantized) { + CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8); + auto data_range = (data.dtype() == mshadow::kInt8) ? int8_range : uint8_range; + float data_scale = data_range / MaxAbs(cached_data_min_, cached_data_max_); + MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, { + QuantizeConvWeightBias(&cached_weight_, &cached_bias_, + has_bias, data_scale, + mkldnn_param.weight_channelwise_scale, + &weight_scales_); + }); + // Collect scale. + size_t channel = cached_weight_.shape()[0]; + float sum_in_scale = 1.0; + float out_range; + float quantized_out_range; + float output_scale; + if (cached_data_min_ < 0.0) { + // TODO(zhennan): Support int8 input when mkldnn supports. + LOG(FATAL) << "Can't handle negetive value for QuantizeData"; + } + if (mkldnn_param.with_sum) { + auto quantized_sum_range = cached_sum_min_ < 0 ? int8_range : uint8_range; + sum_in_scale = quantized_sum_range / MaxAbs(cached_sum_min_, cached_sum_max_); + } + if (post_requantize) { + quantized_out_range = + IsOutputUInt8(mkldnn_param) ? uint8_range : int8_range; + out_range = MaxAbs(*out_min_ptr, *out_max_ptr); + output_scale = quantized_out_range / out_range; + full_conv_param.requantize_scales.resize(channel); + for (size_t c = 0; c < channel; c++) { + auto weight_scale = mkldnn_param.weight_channelwise_scale + ? weight_scales_[c] + : weight_scales_[0]; + full_conv_param.requantize_scales[c] = + output_scale / data_scale / weight_scale; + } + } else { + output_scale = data_scale * weight_scales_[0]; + full_conv_param.requantize_scales.resize(0); + } + if (mkldnn_param.with_sum) + full_conv_param.sum_scale = output_scale / sum_in_scale; + } + fwd_.reset(new MKLDNNConvForward( + full_conv_param, ctx.is_train, data, cached_weight_, + has_bias ? &cached_bias_ : nullptr, output)); + } + initalized_ = true; + std::vector new_inputs; + std::vector new_req; + if (has_bias) { + new_inputs = {data, cached_weight_, cached_bias_}; + new_req = {req[in_data], req[in_weight], req[in_bias]}; + } else { + new_inputs = {data, cached_weight_}; + new_req = {req[in_data], req[in_weight]}; + } + ConvolutionFusionComputeExCPU(full_conv_param, ctx, fwd_.get(), new_inputs, + new_req, {output}); + + if (mkldnn_param.with_sum) { + auto out = const_cast(outputs[kOut]); + out.UpdateMKLDNNMemDesc(); + } +} + +static void SgMKLDNNConvOpForward(const OpStatePtr &state_ptr, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + SgMKLDNNConvOperator &op = state_ptr.get_state(); + op.Forward(ctx, inputs, req, outputs); +} + +static uint32_t SgMKLDNNConvNumInputs(const NodeAttrs &attrs) { + auto const ¶m = nnvm::get(attrs.parsed); + auto num_input = DefaultSubgraphOpNumInputs(attrs); + if (param.full_conv_param.mkldnn_param.quantized) + return num_input + 2 + param.full_conv_param.mkldnn_param.with_sum ? 2 : 0; + else + return num_input; +} + +static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) { + MKLDNNConvFusionParam param_; + try { + param_.full_conv_param.mkldnn_param.Init(attrs->dict); + } catch (const dmlc::ParamError &e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs->op->name << "(" + << "name=\"" << attrs->name << "\""; + for (const auto &k : attrs->dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + auto subgraph_sym = attrs->subgraphs[0]; + DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr &node) { + if (node->is_variable()) return; + auto &node_name = node->op()->name; + if (node_name == "BatchNorm") { + CHECK_EQ(param_.full_conv_param.mkldnn_param.with_bn, true); + CHECK(param_.bn_param.get() == nullptr); + param_.bn_param = std::make_shared( + nnvm::get(node->attrs.parsed)); + } else if (node_name == "Convolution") { + param_.full_conv_param.conv_param = + nnvm::get(node->attrs.parsed); + } + }); + attrs->parsed = std::move(param_); +} + +static std::vector SgMKLDNNConvListInputNames( + const NodeAttrs &attrs) { + auto const ¶m = nnvm::get(attrs.parsed); + std::vector input_names = DefaultSubgraphOpListInputs(attrs); + if (param.full_conv_param.mkldnn_param.quantized) { + input_names.emplace_back("data_min"); + input_names.emplace_back("data_max"); + if (param.full_conv_param.mkldnn_param.with_sum) { + input_names.emplace_back("sum_min"); + input_names.emplace_back("sum_max"); + } + } + return input_names; +} + +static std::vector SgMKLDNNConvListOutputNames( + const NodeAttrs &attrs) { + auto const ¶m = nnvm::get(attrs.parsed); + if (param.full_conv_param.mkldnn_param.quantized) + return std::vector{"output", "output_min", "output_max"}; + else + return std::vector{"output"}; +} + +static OpStatePtr CreateSgMKLDNNConvState(const nnvm::NodeAttrs &attrs, + Context ctx, + const std::vector &in_shapes, + const std::vector &in_types) { + return OpStatePtr::Create(attrs); +} + +template +static void FilterMinMaxIndice(const MKLDNNConvParam &mkldnn_param, + std::vector *in_shapes, + std::vector *out_shapes, + std::vector *base_in_shapes, + std::vector *base_out_shapes, + std::unordered_set *minmax_indice) { + base_out_shapes->push_back(out_shapes->at(0)); + size_t last = in_shapes->size() - 1; + if (mkldnn_param.with_sum) { + minmax_indice->insert(last); + minmax_indice->insert(last - 1); + minmax_indice->insert(last - 2); + minmax_indice->insert(last - 3); + *base_in_shapes = + std::vector(in_shapes->begin(), in_shapes->end() - 4); + } else { + minmax_indice->insert(last); + minmax_indice->insert(last - 1); + *base_in_shapes = + std::vector(in_shapes->begin(), in_shapes->end() - 2); + } +} + +static bool SgMKLDNNConvInferShape(const nnvm::NodeAttrs &attrs, + std::vector *in_shapes, + std::vector *out_shapes) { + auto const ¶m = nnvm::get(attrs.parsed); + if (param.full_conv_param.mkldnn_param.quantized) { + std::unordered_set minmax_indice; + std::vector base_in_shapes; + std::vector base_out_shapes; + + FilterMinMaxIndice(param.full_conv_param.mkldnn_param, in_shapes, + out_shapes, &base_in_shapes, &base_out_shapes, + &minmax_indice); + bool result = + DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes); + size_t base_idx = 0; + for (size_t i = 0; i < in_shapes->size(); ++i) { + if (minmax_indice.count(i)) { + SHAPE_ASSIGN_CHECK(*in_shapes, i, Shape1(1)); + } else { + in_shapes->at(i) = base_in_shapes[base_idx++]; + } + } + out_shapes->at(0) = base_out_shapes[0]; + SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1)); + SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1)); + return result; + } else { + return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes); + } +} + +static bool SgMKLDNNConvInferType(const nnvm::NodeAttrs &attrs, + std::vector *in_types, + std::vector *out_types) { + auto const ¶m = nnvm::get(attrs.parsed); + if (param.full_conv_param.mkldnn_param.quantized) { + std::unordered_set minmax_indice; + std::vector base_in_types; + std::vector base_out_types; + FilterMinMaxIndice(param.full_conv_param.mkldnn_param, in_types, + out_types, &base_in_types, &base_out_types, + &minmax_indice); + // Override data type to fp32 for default infer type as bn doesn't support + // uint8. + int orig_data = base_in_types[0]; + base_in_types[0] = mshadow::kFloat32; + int orig_sum = base_in_types[0]; + if (param.full_conv_param.mkldnn_param.with_sum) { + auto sum_index = GetInSumIndex(param); + orig_sum = base_in_types[sum_index]; + base_in_types[sum_index] = mshadow::kFloat32; + } + bool result = DefaultSubgraphOpType(attrs, &base_in_types, &base_out_types); + base_in_types[0] = orig_data; + if (param.full_conv_param.mkldnn_param.with_sum) { + auto sum_index = GetInSumIndex(param); + base_in_types[sum_index] = orig_sum; + } + size_t base_idx = 0; + for (size_t i = 0; i < in_types->size(); ++i) { + if (minmax_indice.count(i)) { + TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32); + } else { + in_types->at(i) = base_in_types[base_idx++]; + } + } + if (param.full_conv_param.mkldnn_param.min_calib_range.has_value() && + param.full_conv_param.mkldnn_param.max_calib_range.has_value()) { + if (IsOutputUInt8(param.full_conv_param.mkldnn_param)) { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8); + } else { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8); + } + } else { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32); + } + + TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32); + return result; + } else { + return DefaultSubgraphOpType(attrs, in_types, out_types); + } +} + +static bool SgMKLDNNConvOpStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector *in_stypes, + std::vector *out_stypes) { + auto const ¶m = nnvm::get(attrs.parsed); + if (param.full_conv_param.mkldnn_param.quantized) { + std::unordered_set minmax_indice; + std::vector base_in_stypes; + std::vector base_out_stypes; + FilterMinMaxIndice(param.full_conv_param.mkldnn_param, in_stypes, + out_stypes, &base_in_stypes, &base_out_stypes, + &minmax_indice); + bool result = DefaultSubgraphOpStorageType( + attrs, dev_mask, dispatch_mode, &base_in_stypes, &base_out_stypes); + size_t base_idx = 0; + for (size_t i = 0; i < in_stypes->size(); ++i) { + if (minmax_indice.count(i)) { + type_assign(&in_stypes->at(i), mxnet::kDefaultStorage); + } else { + in_stypes->at(i) = base_in_stypes[base_idx++]; + } + } + out_stypes->at(0) = base_out_stypes[0]; + type_assign(&out_stypes->at(1), mxnet::kDefaultStorage); + type_assign(&out_stypes->at(2), mxnet::kDefaultStorage); + return result; + } else { + return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, + in_stypes, out_stypes); + } +} + +std::vector> SgMKLDNNConvInplaceOption( + const NodeAttrs &attrs) { + auto const ¶m = nnvm::get(attrs.parsed); + if (param.full_conv_param.mkldnn_param.with_sum) { + return std::vector>{{GetInSumIndex(param), 0}}; + } else { + return std::vector>(); + } +} + +nnvm::NodePtr SgMKLDNNConvQuantizedOp(const NodeAttrs& attrs) { + nnvm::NodePtr node = nnvm::Node::Create(); + node->attrs.op = Op::Get("_sg_mkldnn_conv"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + node->attrs.dict["quantized"] = "true"; + node->attrs.subgraphs.reserve(attrs.subgraphs.size()); + for (auto sub : attrs.subgraphs) { + node->attrs.subgraphs.push_back(sub); + } + node->op()->attr_parser(&(node->attrs)); + return node; +} + +bool SgMKLDNNAvoidQuantizeInput(const NodeAttrs &attrs, size_t index) { + auto const ¶m = nnvm::get(attrs.parsed); + std::unordered_set avoid_indice; + size_t idx = 0; + idx++; // data + avoid_indice.insert(idx++); // weight + if (!param.full_conv_param.conv_param.no_bias) { + avoid_indice.insert(idx++); // bias + } + if (param.full_conv_param.mkldnn_param.with_bn) { + avoid_indice.insert(idx++); // gamma + avoid_indice.insert(idx++); // beta + avoid_indice.insert(idx++); // mean + avoid_indice.insert(idx++); // var + } + return avoid_indice.count(index); +} + +NNVM_REGISTER_OP(_sg_mkldnn_conv) +.describe(R"code(_sg_mkldnn_conv)code" ADD_FILELINE) +.set_num_inputs(SgMKLDNNConvNumInputs) +.set_num_outputs([](const NodeAttrs& attrs) { + auto const ¶m = nnvm::get(attrs.parsed); + return param.full_conv_param.mkldnn_param.quantized ? 3 : 1; +}) +.set_attr_parser(SgMKLDNNConvParamParser) +.set_attr("FListInputNames", SgMKLDNNConvListInputNames) +.set_attr("FListOutputNames", SgMKLDNNConvListOutputNames) +.set_attr("FCreateOpState", CreateSgMKLDNNConvState) +.set_attr("FInferShape", SgMKLDNNConvInferShape) +.set_attr("FInferType", SgMKLDNNConvInferType) +.set_attr("FInferStorageType", SgMKLDNNConvOpStorageType) +.set_attr("FStatefulComputeEx", SgMKLDNNConvOpForward) +.set_attr("TIsMKLDNN", true) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("FMutateInputs", + DefaultSubgraphOpMutableInputs) +.set_attr("key_var_num_args", "num_args") +.set_attr("FInplaceOption", SgMKLDNNConvInplaceOption) +.set_attr("FQuantizedOp", SgMKLDNNConvQuantizedOp) +.set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) +.set_attr("FAvoidQuantizeInput", SgMKLDNNAvoidQuantizeInput); + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_MKLDNN == 1 diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc new file mode 100644 index 000000000000..fc68287b039d --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#if MXNET_USE_MKLDNN == 1 + +#include "../common.h" +#include "../subgraph_property.h" +#include "../../nn/mkldnn/mkldnn_convolution-inl.h" +#include "mkldnn_conv-inl.h" +#include "../../quantization/requantize-inl.h" + +namespace mxnet { +namespace op { + +class SgMKLDNNConvPostQuantizeSelector : public SubgraphSelector { + public: + /*! \brief pattern match status */ + enum SelectStatus { + kFail = 0, + kStart, + kSuccess, + }; + + private: + bool disable_all; + SelectStatus status; + std::vector matched_list; + + public: + explicit SgMKLDNNConvPostQuantizeSelector(int dis_all) + : disable_all(dis_all) {} + + bool Select(const nnvm::Node &n) override { + if ((!disable_all) && n.op() && n.op()->name == "_sg_mkldnn_conv") { + auto const ¶m = nnvm::get(n.attrs.parsed); + if (param.full_conv_param.mkldnn_param.quantized) { + status = kStart; + matched_list.clear(); + matched_list.push_back(&n); + return true; + } + } + return false; + } + + bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override { + return false; + } + + bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { + if (status == kFail || status == kSuccess || new_node.is_variable()) + return false; + // If n isn't the last matched node, then we encoutered a internal + // branch, we should pop out the node behind n and stop fusion. + if (matched_list.back() != &n) { + status = kFail; + return false; + } + if (new_node.op()->name == "_contrib_requantize") { + auto const ¶m = nnvm::get(new_node.attrs.parsed); + if (param.min_calib_range.has_value() && + param.max_calib_range.has_value()) { + matched_list.push_back(&new_node); + status = kSuccess; + return true; + } else { + status = kFail; + } + } + return false; + } + + std::vector Filter( + const std::vector &candidates) override { + if (status != kSuccess) { + return std::vector(0); + } else { + return candidates; + } + } +}; + +class SgMKLDNNConvPostQuantizeProperty : public SubgraphProperty { + public: + SgMKLDNNConvPostQuantizeProperty() { + disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_OPT", 0); + if (disable_all) { + LOG(INFO) << "MKLDNN Convolution post-quantization optimization pass is disabled."; + } else { + LOG(INFO) << "Start to execute MKLDNN Convolution post-quantization optimization pass."; + } + } + static SubgraphPropertyPtr Create() { + return std::make_shared(); + } + nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, + const int subgraph_id = 0) const override { + nnvm::NodePtr conv_node = nullptr; + nnvm::NodePtr requantize_node = nullptr; + DFSVisit(sym.outputs, [&](const nnvm::NodePtr &node) { + if (node->is_variable()) return; + auto &op_name = node->op()->name; + if (op_name == "_sg_mkldnn_conv") { + conv_node = node; + } else if (op_name == "_contrib_requantize") { + requantize_node = node; + } + }); + CHECK_NOTNULL(conv_node); + CHECK_NOTNULL(requantize_node); + auto const &requantize_param = + nnvm::get(requantize_node->attrs.parsed); + CHECK(requantize_param.min_calib_range.has_value()); + CHECK(requantize_param.max_calib_range.has_value()); + conv_node->attrs.dict["min_calib_range"] = + std::to_string(requantize_param.min_calib_range.value()); + conv_node->attrs.dict["max_calib_range"] = + std::to_string(requantize_param.max_calib_range.value()); + conv_node->op()->attr_parser(&(conv_node->attrs)); + return conv_node; + } + + SubgraphSelectorPtr CreateSubgraphSelector() const override { + auto selector = + std::make_shared(disable_all); + return selector; + } + + void ConnectSubgraphOutputs( + const nnvm::NodePtr n, + std::vector *output_entries) const override { + for (size_t i = 0; i < output_entries->size(); ++i) { + auto entry_ptr = output_entries->at(i); + *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0}; + } + } + + private: + int disable_all; +}; + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_QUANTIZE, SgMKLDNNConvPostQuantizeProperty); + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_MKLDNN == 1 diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc new file mode 100644 index 000000000000..e5220f24d34d --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#if MXNET_USE_MKLDNN == 1 + +#include "../common.h" +#include "../subgraph_property.h" +#include "../../nn/activation-inl.h" + +namespace mxnet { +namespace op { +class SgMKLDNNConvSelector : public SubgraphSelector { + public: + /*! \brief pattern match status */ + enum SelectStatus { + kFail = 0, + kStart, + kBN, + kSum, + kSuccess, + }; + + private: + bool disable_all; + bool disable_conv_bn; + bool disable_conv_relu; + bool disable_conv_sum; + SelectStatus status; + std::vector matched_list; + + public: + SgMKLDNNConvSelector(int dis_all, int dis_conv_bn, int dis_conv_relu, int dis_conv_sum) + : disable_all(dis_all), + disable_conv_bn(dis_conv_bn), + disable_conv_relu(dis_conv_relu), + disable_conv_sum(dis_conv_sum) {} + + bool Select(const nnvm::Node &n) override { + if (n.op() && n.op()->name == "Convolution") { + status = disable_all ? kSuccess : kStart; + matched_list.clear(); + matched_list.push_back(&n); + return true; + } + return false; + } + + bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override { + return false; + } + + bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { + if (status == kFail || status == kSuccess || new_node.is_variable()) + return false; + // If n isn't the last matched node, then we encoutered a internal + // branch, we should pop out the node behind n and stop fusion. + if (matched_list.back() != &n) { + while (matched_list.back() != &n) { + matched_list.pop_back(); + } + status = kSuccess; + return false; + } + // Use status machine to do selection. The status change is + // kStart -> kBN -> kSum -> kSuccess + switch (status) { + case kStart: + if ((!disable_conv_bn) && new_node.op()->name == "BatchNorm") { + matched_list.push_back(&new_node); + status = kBN; + return true; + } + case kBN: + if ((!disable_conv_sum) && new_node.op()->name == "elemwise_add") { + matched_list.push_back(&new_node); + status = kSum; + return true; + } + case kSum: + default: + if ((!disable_conv_relu) && new_node.op()->name == "Activation") { + const ActivationParam ¶m = + nnvm::get(new_node.attrs.parsed); + if (param.act_type == activation::kReLU) { + matched_list.push_back(&new_node); + // If we find conv+relu, then we can't match bn anymore. + if (status == kStart) status = kBN; + return true; + } else { + status = kSuccess; + return false; + } + } + status = kSuccess; + return false; + } + } + + std::vector Filter( + const std::vector &candidates) override { + if (status == kFail) { + return std::vector(0); + } else { + return candidates; + } + } +}; + +class SgMKLDNNConvProperty : public SubgraphProperty { + public: + SgMKLDNNConvProperty() { + disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_OPT", 0); + disable_conv_bn = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_BN", 0); + disable_conv_relu = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_RELU", 0); + disable_conv_sum = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_SUM", 0); + + disable_all = + disable_all && disable_conv_bn && disable_conv_relu && disable_conv_sum; + if (disable_all) { + LOG(INFO) << "MKLDNN Convolution optimization pass is disabled."; + } else { + LOG(INFO) << "Start to execute MKLDNN Convolution optimization pass."; + } + } + static SubgraphPropertyPtr Create() { + return std::make_shared(); + } + nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, + const int subgraph_id = 0) const override { + nnvm::NodePtr n = nnvm::Node::Create(); + // This op has single output, remove duplicated. + auto last_node = sym.outputs[0].node; + nnvm::Symbol new_sym; + new_sym.outputs.emplace_back(nnvm::NodeEntry{last_node, 0, 0}); + std::ostringstream node_name; + node_name << "sg_mkldnn_"; + bool _with_sum = false; + DFSVisit(new_sym.outputs, [&](const nnvm::NodePtr &node) { + if (node->is_variable()) return; + auto &sub_name = node->op()->name; + if (sub_name == "Convolution") { + node_name << "conv_"; + } else if (sub_name == "BatchNorm") { + node_name << "bn_"; + n->attrs.dict["with_bn"] = "true"; + } else if (sub_name == "elemwise_add") { + node_name << "add_"; + n->attrs.dict["with_sum"] = "true"; + _with_sum = true; + + } else if (sub_name == "Activation") { + node_name << "relu_"; + if (!_with_sum) { + n->attrs.dict["with_relu"] = "true"; + } else { + n->attrs.dict["with_postsum_relu"] = "true"; + } + } + }); + node_name << std::to_string(subgraph_id); + n->attrs.name = node_name.str(); + n->attrs.op = Op::Get("_sg_mkldnn_conv"); + CHECK(n->attrs.op); + n->attrs.subgraphs.emplace_back(std::make_shared(new_sym)); + n->op()->attr_parser(&(n->attrs)); + return n; + } + + SubgraphSelectorPtr CreateSubgraphSelector() const override { + auto selector = std::make_shared( + disable_all, disable_conv_bn, disable_conv_relu, disable_conv_sum); + return selector; + } + + void ConnectSubgraphOutputs( + const nnvm::NodePtr n, + std::vector *output_entries) const override { + // Connect all extern output entries to output[0] + for (size_t i = 0; i < output_entries->size(); ++i) { + *output_entries->at(i) = nnvm::NodeEntry{n, 0, 0}; + } + } + + void ConnectSubgraphInputs( + const nnvm::NodePtr n, std::vector *input_entries, + std::vector *orig_input_entries) const override { + auto sym = n->attrs.subgraphs[0]; + std::unordered_set node_sets; + DFSVisit(sym->outputs, [&](const nnvm::NodePtr &node) { + if (node->is_variable()) return; + node_sets.insert(node.get()); + if (node->op()->name == "elemwise_add") { + // Make sure n is the left operand of sum, if not, + // switch sum operands sequence to ensure that + // the extra sum operand stays in the last of inputs. + if (node_sets.count(node->inputs[1].node.get())) { + auto tmp = node->inputs[1]; + node->inputs[1] = node->inputs[0]; + node->inputs[0] = tmp; + std::rotate(input_entries->begin(), input_entries->begin() + 1, + input_entries->end()); + std::rotate(orig_input_entries->begin(), + orig_input_entries->begin() + 1, + orig_input_entries->end()); + } + } + }); + n->inputs = *orig_input_entries; + } + + private: + int disable_all; + int disable_conv_bn; + int disable_conv_relu; + int disable_conv_sum; +}; + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty); + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_MKLDNN == 1 diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc index 315f7eec00c6..da9a9f375fa5 100644 --- a/src/operator/subgraph/partition_graph.cc +++ b/src/operator/subgraph/partition_graph.cc @@ -653,10 +653,9 @@ void CreateSubgraphNode(Graph* g, nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym, subgraph_id); // Connect the external nodes to the subgraph node. - for (size_t i = 0; i < output_entries.size(); ++i) { - *output_entries[i] = nnvm::NodeEntry{n, static_cast(i), 0}; - } - n->inputs = orig_input_entries; + subg_prop->ConnectSubgraphOutputs(n, &output_entries); + subg_prop->ConnectSubgraphInputs(n, &input_entries, &orig_input_entries); + const auto& indexed_graph = g->indexed_graph(); for (size_t i = 0; i < n->inputs.size(); ++i) { auto& e = n->inputs[i]; diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index cfbc1f837337..e9fdd6619275 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -62,16 +62,22 @@ class SubgraphSelector { * \brief Determines if to select input_node when traverse to the cur_node. * \param cur_node the node for determining whether its input_node should be selected * \param input_node the input node of the cur_node + * \return true if input_node is selected */ virtual bool SelectInput(const nnvm::Node &cur_node, const nnvm::Node &input_node) = 0; /*! * \brief Determines if to select output_node when traverse to the cur_node. * \param cur_node the node for determining whether its output_node should be selected * \param output_node the output node of the cur_node + * \return true if output_node is selected */ virtual bool SelectOutput(const nnvm::Node &cur_node, const nnvm::Node &output_node) = 0; - // Post processes pre-selected subgraph nodes. Return a list of nodes that - // users want to keep in subgraph(s). + /*! + * \brief Post processes pre-selected subgraph nodes. Return a list of nodes that + * users want to keep in subgraph(s). + * \param candidates re-selected subgraph nodes to filt + * \return a list of nodes to keep + */ virtual std::vector Filter(const std::vector& candidates) { return candidates; } @@ -81,30 +87,65 @@ using SubgraphSelectorPtr = std::shared_ptr; /*! * \brief This provides a set of properties for partitioning a graph into subgraphs, - * reconstructing a new graph from the subgraphs and creating a subgraph - * operator to execute the subgraph. + * reconstructing a new graph from the subgraphs and creating a subgraph + * operator to execute the subgraph. */ class SubgraphProperty { public: - // the criteria of selecting the subgraph nodes. + /*! + * \brief The criteria of selecting the subgraph nodes. + */ virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0; - // create an nnvm node for a given subgraph. Here users can customize how to - // execute the operators in the subgraph. - virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s, + /*! + * \brief Create an nnvm node for a given subgraph. Here users can customize how to + * execute the operators in the subgraph. + * \param sym the symbol to create subgraph node + * \param subgraph_id subgraph id + */ + virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, const int subgraph_id = 0) const = 0; - // set an attr with name in the attr map + /*! + * \brief Connect subgraph internal output with external output entries. + * By default, each output entry will connect to an unique internal output. + * \param subgraph_node the subgraph node to connect output + * \param output_entries external output entries depending on this subgraph node + */ + virtual void ConnectSubgraphOutputs(const nnvm::NodePtr subgraph_node, + std::vector* output_entries) const { + for (size_t i = 0; i < output_entries->size(); ++i) { + *output_entries->at(i) = nnvm::NodeEntry{subgraph_node, static_cast(i), 0}; + } + } + /*! + * \brief Connect subgraph internal input with external input entries. + * By default, each input entry will connect in top sorted order. + * \param subgraph_node the subgraph node to connect input + * \param input_entries input entries inside subgraph + * \param orig_input_entries input entries outside subgraph + */ + virtual void ConnectSubgraphInputs(const nnvm::NodePtr subgraph_node, + std::vector* input_entries, + std::vector* orig_input_entries) const { + subgraph_node->inputs = *orig_input_entries; + } + /*! + * \brief Set an attr with name in the attr map. + */ template SubgraphProperty& SetAttr(const std::string& name, const T& value) { attrs_[name] = std::make_shared(value); return *this; } - // get the attr with the name + /*! + * \brief Get the attr with the name. + */ template const T& GetAttr(const std::string& name) const { auto it = attrs_.find(name); CHECK(it != attrs_.end()) << "Cannot find attribute " << name << " in SubgraphProperty"; return nnvm::get(*it->second); } + protected: std::unordered_map> attrs_; }; diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py new file mode 100644 index 000000000000..5b708216e2ac --- /dev/null +++ b/tests/python/mkl/test_subgraph.py @@ -0,0 +1,487 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys +import os +import mxnet as mx +import numpy as np +import unittest +import ctypes +from mxnet.io import NDArrayIter +from mxnet.module import Module +from mxnet.symbol import Symbol +from importlib import import_module +from numpy.testing import assert_allclose +from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str +from mxnet.test_utils import DummyIter +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.append(os.path.join(curr_path, '../unittest/')) +from common import with_seed +from mxnet.test_utils import assert_almost_equal + +DATA_SHAPE=[(4, 4, 10, 10), (32, 3, 24, 24), (64, 8, 64, 64)] + +def check_qsym_calibrated(qsym): + assert ''.join(qsym.attr_dict().keys()).find('quantized_sg_mkldnn_conv') != -1 + for k, v in qsym.attr_dict().items(): + if k.find('quantized_sg_mkldnn_conv') != -1: + assert 'min_calib_range' in v + assert 'max_calib_range' in v + if k.find('_quantize') != -1: + assert v['out_type'] == 'uint8' + +def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape): + mod = mx.mod.Module(symbol=qsym, context=mx.current_context()) + mod.bind(for_training=False, + data_shapes=[('data', data_shape)], + label_shapes=[('softmax_label', label_shape)]) + mod.set_params(qarg_params, qaux_params) + mod.forward(batch, is_train=False) + for output in mod.get_outputs(): + output.wait_to_read() + return mod.get_outputs() + +def check_quantize(sym, data_shape): + fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') + sym_sg = sym.get_backend_symbol("MKLDNN") + label_shape = (data_shape[0], 10) + mod = Module(symbol=sym) + mod.bind(for_training=False, + data_shapes=[('data', data_shape)], + label_shapes=[('softmax_label', label_shape)]) + mod.init_params(mx.init.Normal(0.5)) + arg_params, aux_params = mod.get_params() + + data = [mx.random.uniform(-1, 1, shape=shape, ctx=mx.current_context()) for _, shape in mod.data_shapes] + batch = mx.io.DataBatch(data, []) + + mod.forward(batch, is_train=False) + for output in mod.get_outputs(): + output.wait_to_read() + ref_out = mod.get_outputs() + + excluded_sym_names = [] + if mx.current_context() == mx.cpu(): + excluded_sym_names += ['fc'] + + calib_data = mx.nd.random.uniform(shape=data_shape) + calib_data = NDArrayIter(data=calib_data) + calib_data = DummyIter(calib_data) + calib_layer = lambda name: name.endswith('_output') + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, + arg_params=arg_params, + aux_params=aux_params, + ctx=mx.current_context(), + excluded_sym_names=excluded_sym_names, + quantized_dtype='uint8', + calib_mode='naive', + calib_data=calib_data, + calib_layer=calib_layer, + calib_quantize_op=True, + num_calib_examples=5) + qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE") + check_qsym_calibrated(qsym) + quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape) + for i in range(len(ref_out)): + assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1) + + +@with_seed() +def check_fusion(sym, data_shape, attrs_op): + sym_sg = sym.get_backend_symbol("MKLDNN") + assert ''.join(sym_sg.get_internals().list_outputs()).find('sg_mkldnn_conv') != -1 + for k, v in sym_sg.attr_dict().items(): + if k.find('sg_mkldnn_conv') != -1: + for attr_op in attrs_op: + assert v[attr_op] == 'true' + + arg_shapes, _, aux_shapes = sym.infer_shape() + arg_array = [mx.nd.random.uniform(-1, 1, shape=shape) for shape in arg_shapes] + aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes] + exe = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') + exe.forward() + os.environ['MXNET_SUBGRAPH_BACKEND'] = 'MKLDNN' + exe_sg = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') + exe_sg.forward() + del os.environ['MXNET_SUBGRAPH_BACKEND'] + for i in range(len(exe.outputs)): + assert_almost_equal(exe.outputs[i].asnumpy(), exe_sg.outputs[i].asnumpy(), rtol=1e-3, atol=1e-3) + + # fp32 to uint8 + check_quantize(sym, data_shape) + +def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, date_shape=(4,4,10,10)): + for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs): + sym_sg = sym.get_backend_symbol("MKLDNN") + exe_sg = sym_sg.simple_bind(mx.cpu(), data=date_shape, grad_req='null') + + attrs_dict = sym_sg.attr_dict() + for k, v in attrs_dict.items(): + if k.find('sg_mkldnn_conv') != -1: + for attr in attrs: + assert v[attr] == 'true' + for exc_attr in excluded_attr: + assert exc_attr not in v.keys() + +def head_symbol(data_shape): + data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') + weight = mx.symbol.Variable('weight', dtype='float32') + bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn') + return bn, weight + +# single conv fuision case +def single_conv(no_bias, data_shape): + conv_attr = [''] + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + return conv, conv_attr + +# conv + bn fusion case +def conv_bn(no_bias, data_shape): + conv_bn_attr = ['with_bn'] + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") + return bn1, conv_bn_attr + +# conv + relu fusion case +def conv_relu(no_bias, data_shape): + conv_relu_attr = ['with_relu'] + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + relu = mx.symbol.Activation(data=conv, name='relu', act_type="relu") + return relu, conv_relu_attr + +# conv + add fusion case +def conv_add(no_bias, data_shape): + conv_add_attr = ['with_sum'] + data, weight = head_symbol(data_shape) + conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + conv2 = mx.symbol.Convolution(data=data, name='conv2', num_filter=64, + kernel=(3, 3), stride=(1, 1)) + pool = mx.sym.Pooling(data=conv2, kernel=(1, 1), pool_type='avg', name='pool') + sum = conv1 + pool + return sum, conv_add_attr + +# conv + add fusion case 2 +def conv_add2(no_bias, data_shape): + conv_add_attr = ['with_sum'] + data, weight = head_symbol(data_shape) + conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + conv2 = mx.symbol.Convolution(data=data, name='conv2', num_filter=64, + kernel=(3, 3), stride=(1, 1)) + pool = mx.sym.Pooling(data=conv2, kernel=(1, 1), pool_type='avg', name='pool') + sum = pool + conv1 + return sum, conv_add_attr + +# conv + bn + relu fusion case +def conv_bn_relu(no_bias, data_shape): + conv_bn_relu_attr = ['with_bn', 'with_relu'] + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") + relu = mx.symbol.Activation(data=bn1, name='relu', act_type="relu") + return relu, conv_bn_relu_attr + +# conv + bn + add + relu fusion case +def conv_bn_sum_relu(no_bias, data_shape): + conv_bn_add_relu_attr = ['with_sum', 'with_postsum_relu', 'with_bn'] + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") + conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, + kernel=(3, 3), stride=(1, 1)) + sum1 = bn1 + conv1 + relu = mx.symbol.Activation(data=sum1, name='relu', act_type="relu") + return relu, conv_bn_add_relu_attr + +def tail_neg_symbol(sym1, sym2): + fc1 = mx.sym.FullyConnected(data=sym1, num_hidden=10, flatten=True, name='fc1') + fc2 = mx.sym.FullyConnected(data=sym2, num_hidden=10, flatten=True, name='fc2') + concat = mx.sym.Concat(*[fc1, fc2], name="concat") + sym = mx.sym.SoftmaxOutput(data=concat, name='softmax') + return sym + +# conv + bn can't be fusion case +# eg.1 +# conv --------- > bn +# | +# | +# -------------> [custom op] +def neg_conv_bn(data_shape): + syms = [] + attrs = [] + excluded_attrs = [] + data, weight = head_symbol(data_shape) + + # eg.1 ([custom op] = pool) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") + pool = mx.sym.Pooling(data=conv, kernel=(4, 4), pool_type='avg', name='pool') + sym = tail_neg_symbol(bn1, pool) + + syms.append(sym) + attrs.append([]) + excluded_attrs.append([]) + return syms, attrs, excluded_attrs + +# conv + relu can't be fusion case +# eg.1 +# conv -----------> relu +# | +# | +# ---------------> [custom op] +def neg_conv_relu(data_shape): + syms = [] + attrs = [] + excluded_attrs = [] + data, weight = head_symbol(data_shape) + + # eg.1 ([custom op] = pool) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + relu = mx.symbol.Activation(data=conv, name='relu', act_type="relu") + pool = mx.sym.Pooling(data=conv, kernel=(4, 4), pool_type='avg', name='pool') + sym = tail_neg_symbol(relu, pool) + + syms.append(sym) + attrs.append([]) + excluded_attrs.append([]) + return syms, attrs, excluded_attrs + +# conv + add can't be fusion case +# eg.1 +# ---------------> [custom op] +# | +# | +# conv -----------> add +# | +# | +# added ------------> +def neg_conv_add(data_shape): + syms = [] + attrs = [] + excluded_attrs = [] + val = mx.symbol.Variable('addval') + data, weight = head_symbol(data_shape) + + # eg.1 ([custom op] = pool, [added op] = val) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + sum1 = conv + val + pool = mx.sym.Pooling(data=conv, kernel=(4, 4), pool_type='avg', name='pool') + sym = tail_neg_symbol(sum1, pool) + + syms.append(sym) + attrs.append([]) + excluded_attrs.append('with_sum') + return syms, attrs, excluded_attrs + +# conv + bn + relu can't be fusion case +# eg.1 +# --------------> [custom op] +# | +# conv -----------> bn -----------> relu +# +# eg.2 +# --------------> [custom op] +# | +# conv -----------> bn -----------> relu +def neg_conv_bn_relu(data_shape): + syms = [] + attrs = [] + excluded_attrs = [] + data, weight = head_symbol(data_shape) + + # eg.1 ([custom op] = pool11) + conv11 = mx.symbol.Convolution(data=data, weight=weight, name='conv11', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn11 = mx.symbol.BatchNorm(data=conv11, name="bn11") + relu11 = mx.symbol.Activation(data=bn11, name='relu11', act_type="relu") + pool11 = mx.sym.Pooling(data=conv11, kernel=(4, 4), pool_type='avg', name='pool11') + sym1 = tail_neg_symbol(relu11, pool11) + + syms.append(sym1) + attrs.append([]) + excluded_attrs.append([]) + + # eg.2 ([custom op] = pool) + conv21 = mx.symbol.Convolution(data=data, weight=weight, name='conv21', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn21 = mx.symbol.BatchNorm(data=conv21, name="bn21") + relu21 = mx.symbol.Activation(data=bn21, name='relu21', act_type="relu") + pool21 = mx.sym.Pooling(data=bn21, kernel=(4, 4), pool_type='avg', name='pool21') + sym2 = tail_neg_symbol(relu21, pool21) + + syms.append(sym2) + attrs.append(['with_bn']) + excluded_attrs.append(['with_relu']) + return syms, attrs, excluded_attrs + +# conv + bn + add + relu can't be fusion case +# eg.1 +# --------------> [custom op] +# | +# conv -----------> bn -----------> add -----------> relu +# +# eg.2 +# -------------> [custom op] +# | +# conv -----------> bn -----------> add -----------> relu +# +# eg.3 +# --------------> [custom op] +# | +# conv -----------> bn -----------> add -----------> relu +def neg_conv_bn_add_relu(data_shape): + syms = [] + attrs = [] + excluded_attrs = [] + addVal = mx.symbol.Variable('addval') + data, weight = head_symbol(data_shape) + + # eg.1 + conv11 = mx.symbol.Convolution(data=data, weight=weight, name='conv11', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn11 = mx.symbol.BatchNorm(data=conv11, name="bn11") + sum11 = bn11 + addVal + relu11 = mx.symbol.Activation(data=sum11, name='relu11', act_type="relu") + pool11 = mx.sym.Pooling(data=conv11, kernel=(4, 4), pool_type='avg', name='pool11') + sym1 = tail_neg_symbol(relu11, pool11) + + syms.append(sym1) + attrs.append([]) + excluded_attrs.append(['with_sum', 'with_postsum_relu', 'with_bn']) + + # eg.2 + conv21 = mx.symbol.Convolution(data=data, weight=weight, name='conv21', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn21 = mx.symbol.BatchNorm(data=conv21, name="bn21") + sum21 = bn21 + addVal + relu21 = mx.symbol.Activation(data=sum21, name='relu21', act_type="relu") + pool21 = mx.sym.Pooling(data=bn21, kernel=(4, 4), pool_type='avg', name='pool21') + sym2 = tail_neg_symbol(relu21, pool21) + + syms.append(sym2) + attrs.append(['with_bn']) + excluded_attrs.append(['with_sum', 'with_postsum_relu']) + + # eg.3 + conv31 = mx.symbol.Convolution(data=data, weight=weight, name='conv31', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn31 = mx.symbol.BatchNorm(data=conv31, name="bn31") + sum31 = bn31 + addVal + relu31 = mx.symbol.Activation(data=sum31, name='relu31', act_type="relu") + pool31 = mx.sym.Pooling(data=sum31, kernel=(4, 4), pool_type='avg', name='pool31') + sym3 = tail_neg_symbol(relu31, pool31) + + syms.append(sym3) + attrs.append(['with_bn', 'with_sum']) + excluded_attrs.append(['with_postsum_relu']) + return syms, attrs, excluded_attrs + +@with_seed() +def test_pos_single_conv(): + for data_shape in DATA_SHAPE: + net, attrs = single_conv(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = single_conv(True, data_shape) + check_fusion(net, data_shape, attrs) + +@with_seed() +def test_pos_conv_relu(): + for data_shape in DATA_SHAPE: + net, attrs = conv_relu(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_relu(True, data_shape) + check_fusion(net, data_shape, attrs) + +@with_seed() +def test_pos_conv_bn(): + for data_shape in DATA_SHAPE: + net, attrs = conv_bn(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_bn(True, data_shape) + check_fusion(net, data_shape, attrs) + +@with_seed() +def test_pos_conv_add(): + for data_shape in DATA_SHAPE: + net, attrs = conv_add(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_add(True, data_shape) + check_fusion(net, data_shape, attrs) + +@with_seed() +def test_pos_conv_add2(): + for data_shape in DATA_SHAPE: + net, attrs = conv_add2(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_add2(True, data_shape) + check_fusion(net, data_shape, attrs) + +@with_seed() +def test_pos_conv_bn_relu(): + for data_shape in DATA_SHAPE: + net, attrs = conv_bn_relu(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_bn_relu(True, data_shape) + check_fusion(net, data_shape, attrs) + +@with_seed() +def test_pos_conv_bn_sum_relu(): + for data_shape in DATA_SHAPE: + net, attrs = conv_bn_sum_relu(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_bn_sum_relu(True, data_shape) + check_fusion(net, data_shape, attrs) + +@with_seed() +def test_neg_conv_bn(): + for data_shape in DATA_SHAPE: + syms, attrs, excluded_attrs = neg_conv_bn(data_shape) + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + +@with_seed() +def test_neg_conv_relu(): + for data_shape in DATA_SHAPE: + syms, attrs, excluded_attrs = neg_conv_relu(data_shape) + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + +@with_seed() +def test_neg_conv_add(): + for data_shape in DATA_SHAPE: + syms, attrs, excluded_attrs = neg_conv_add(data_shape) + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + +@with_seed() +def test_neg_conv_bn_relu(): + for data_shape in DATA_SHAPE: + syms, attrs, excluded_attrs = neg_conv_bn_relu(data_shape) + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + +@with_seed() +def test_neg_conv_bn_add_relu(): + for data_shape in DATA_SHAPE: + syms, attrs, excluded_attrs = neg_conv_bn_add_relu(data_shape) + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + + +if __name__ == "__main__": + import nose + nose.runmodule() diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 369a923c1879..5ae2c6c398e9 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -374,7 +374,7 @@ def test_quantize_params(): for name in offline_params: params[name] = mx.nd.uniform(shape=(2, 2)) qsym = mx.contrib.quant._quantize_symbol(sym, offline_params=offline_params) - qparams = mx.contrib.quant._quantize_params(qsym, params) + qparams = mx.contrib.quant._quantize_params(qsym, params, th_dict = {}) param_names = params.keys() qparam_names = qparams.keys() for name in qparam_names: @@ -406,7 +406,7 @@ def get_fp32_residual(): fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc') sym = mx.sym.SoftmaxOutput(fc, grad_scale=1, ignore_label=-1, multi_output=False, out_grad=False, preserve_shape=False, use_ignore=False, name='softmax') - return sym + return sym @with_seed() def test_quantize_model(): @@ -418,7 +418,7 @@ def check_params(params, qparams, qsym=None): assert k in qparams assert same(v.asnumpy(), qparams[k].asnumpy()) else: - qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params) + qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params, th_dict = {}) assert len(qparams) == len(qparams_ground_truth) for k, v in qparams_ground_truth.items(): assert k in qparams @@ -494,7 +494,7 @@ def check_params(params, qparams, qsym=None): assert k in qparams assert same(v.asnumpy(), qparams[k].asnumpy()) else: - qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params) + qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params, th_dict = {}) assert len(qparams) == len(qparams_ground_truth) for k, v in qparams_ground_truth.items(): assert k in qparams @@ -525,7 +525,7 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape): mod.forward(batch, is_train=False) for output in mod.get_outputs(): output.wait_to_read() - + sym = get_fp32_residual() mod = Module(symbol=sym)