diff --git a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
index 980b931851f3..1f321e710d59 100644
--- a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
+++ b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
@@ -28,6 +28,8 @@
 #if MXNET_USE_ONEDNN == 1
 
+#include
+#include
 #include
 #include
 
@@ -41,6 +43,8 @@ struct DNNLFCParam : public dmlc::Parameter<DNNLFCParam> {
   bool quantized;
   bool enable_float_output;
   bool with_eltwise;
+  bool with_sum;
+  bool first_quantization_pass;  // True for operator created during first quantization pass
   dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
   dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
   dmlc::optional<bool> channel_wise_quantize;
@@ -54,6 +58,10 @@
     DMLC_DECLARE_FIELD(with_eltwise)
         .set_default(false)
         .describe("Whether there's a post with_eltwise after FullyConnected operator");
+    DMLC_DECLARE_FIELD(with_sum).set_default(false).describe("Add post sum");
+    DMLC_DECLARE_FIELD(first_quantization_pass)
+        .set_default(false)
+        .describe("True for first quantization pass");
     DMLC_DECLARE_FIELD(min_calib_range)
         .set_default(dmlc::optional<float>())
         .describe(
@@ -76,9 +84,86 @@
 struct DNNLFCFullParam {
   FullyConnectedParam default_param;
   DNNLFCParam dnnl_param;
   DNNLPostEltwiseParam eltwise_param;
+  float sum_scale                  = {1.0f};
   std::vector<float> output_scales = {0.0f};
 };
 
+static inline size_t GetInSumIndex(const DNNLFCFullParam& param) {
+  assert(param.dnnl_param.with_sum);
+  return fullc::kWeight + 1 + (param.default_param.no_bias ? 0 : 1);
+}
+
+class FCInputIndex {
+ public:
+  explicit FCInputIndex(const DNNLFCFullParam full_param) {
+    auto& dnnl_param     = full_param.dnnl_param;
+    const bool has_bias  = !full_param.default_param.no_bias;
+    const bool quantized = dnnl_param.quantized;
+    const bool sum_input_quantized =
+        quantized && dnnl_param.with_sum && !dnnl_param.enable_float_output;
+    const bool channel_wise = quantized && dnnl_param.channel_wise_quantize.has_value() &&
+                              dnnl_param.channel_wise_quantize.value();
+
+    // Calculate the position of each particular input in the input vector:
+    int index = 0;
+    data      = index++;
+    weight    = index++;
+    bias      = has_bias ? index++ : 0;
+    sum       = dnnl_param.with_sum ? index++ : 0;
+    num_base  = index;  // note the number of base inputs
+
+    data_min   = quantized ? index++ : 0;
+    data_max   = quantized ? index++ : 0;
+    weight_min = (quantized && !channel_wise) ? index++ : 0;
+    weight_max = (quantized && !channel_wise) ? index++ : 0;
+    bias_min   = (quantized && !channel_wise && has_bias) ? index++ : 0;
+    bias_max   = (quantized && !channel_wise && has_bias) ? index++ : 0;
+    sum_min    = sum_input_quantized ? index++ : 0;
+    sum_max    = sum_input_quantized ?
index++ : 0; + num_total = index; // note number of total inputs + } + + // Returns true if sum input exists + bool IsSumExist() const { + return sum; + } + + // Returns true if bias input exists + bool IsBiasExist() const { + return bias; + } + + // Returns true if sum input exists and it is float number + bool IsSumInputFloat() const { + return (sum && !sum_min); + } + int GetTotal() const { + return num_total; + } + int GetBase() const { + return num_base; + } + + // Represent index of particular input in the input vector: + int data; + int weight; + int bias; + int sum; + int data_min; + int data_max; + int weight_min; + int weight_max; + int bias_min; + int bias_max; + int sum_min; + int sum_max; + + private: + int num_base; // Number of standard inputs + int num_total; // Number of total inputs: standard + additional needed for + // quantization +}; + dnnl::inner_product_forward::primitive_desc GetFCFwdImpl(const DNNLFCFullParam& full_param, const bool is_train, const NDArray& data, diff --git a/src/operator/nn/dnnl/dnnl_fully_connected.cc b/src/operator/nn/dnnl/dnnl_fully_connected.cc index eca90b7cf4c6..6f04b1923b1a 100644 --- a/src/operator/nn/dnnl/dnnl_fully_connected.cc +++ b/src/operator/nn/dnnl/dnnl_fully_connected.cc @@ -53,6 +53,9 @@ dnnl::inner_product_forward::primitive_desc GetFCFwdImpl(const DNNLFCFullParam& full_param.eltwise_param.alpha, full_param.eltwise_param.beta); } + if (full_param.dnnl_param.with_sum) { + ops.append_sum(full_param.sum_scale); + } attr.set_post_ops(ops); if (full_param.dnnl_param.quantized && full_param.output_scales.size()) { diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc index ef1218b49df0..4acaa22cdc1b 100644 --- a/src/operator/subgraph/build_subgraph.cc +++ b/src/operator/subgraph/build_subgraph.cc @@ -749,6 +749,17 @@ void CreateSubgraphNode(nnvm::Graph* g, for (BiDirectedNode* dest_node : subgraph_nodes) { sn->outputs.erase(dest_node->node); } + } + } + + // Set outputs according to current inputs + for (size_t i = 0; i < n->inputs.size(); ++i) { + auto& e = n->inputs[i]; + // update input entries' source simple nodes' outputs map + nnvm::Node* node = e.node.get(); + if (indexed_graph.exist(node)) { + const auto nid = indexed_graph.node_id(node); + BiDirectedNode* sn = simple_nodes[nid].get(); sn->outputs[n.get()].push_back(i); } } diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc b/src/operator/subgraph/dnnl/dnnl_fc.cc index 51989cad3595..22c115214e41 100644 --- a/src/operator/subgraph/dnnl/dnnl_fc.cc +++ b/src/operator/subgraph/dnnl/dnnl_fc.cc @@ -25,7 +25,10 @@ #if MXNET_USE_ONEDNN == 1 +#include #include +#include +#include #include #include @@ -62,8 +65,8 @@ class SgDNNLFCOp { private: bool initialized_{false}; - bool channel_wise_runtime_{false}; bool reorder_data_{false}; + bool inplace_{false}; nnvm::Symbol subgraph_sym_; DNNLFCFullParam full_param_; dnnl_args_map_t args_; @@ -72,83 +75,123 @@ class SgDNNLFCOp { std::shared_ptr cached_out_mem_; NDArray cached_weight_; NDArray cached_bias_; - float cached_min_data_; - float cached_max_data_; - float cached_min_weight_; - float cached_max_weight_; - float cached_min_bias_; - float cached_max_bias_; + float cached_data_min_; + float cached_data_max_; + float cached_weight_min_; + float cached_weight_max_; + float cached_sum_min_; + float cached_sum_max_; + float cached_bias_min_; + float cached_bias_max_; size_t weight_ver_; size_t bias_ver_; - float cached_min_output_; - float cached_max_output_; + float cached_output_min_; + float 
cached_output_max_; float data_scale_{0.0f}; std::vector weight_scales_; - size_t total_num_inputs_; - size_t total_num_outputs_; }; void SgDNNLFCOp::Forward(const OpContext& ctx, const std::vector& in_data, const std::vector& req, const std::vector& out_data) { - auto& dnnl_param = full_param_.dnnl_param; - auto& default_param = full_param_.default_param; - bool has_bias = !default_param.no_bias; - size_t base_num_inputs = has_bias ? 3 : 2; - size_t base_num_outputs = 1; - - float min_data = 0.0f; - float max_data = 0.0f; - float min_weight = 0.0f; - float max_weight = 0.0f; - float min_bias = 0.0f; - float max_bias = 0.0f; - - if (!initialized_) { - if (dnnl_param.channel_wise_quantize.has_value() && dnnl_param.channel_wise_quantize) { - channel_wise_runtime_ = true; + auto& dnnl_param = full_param_.dnnl_param; + auto& default_param = full_param_.default_param; + const bool has_bias = !default_param.no_bias; + const bool quantized = dnnl_param.quantized; + const bool out_quantized = dnnl_param.quantized && !dnnl_param.enable_float_output; + const bool channel_wise = quantized && dnnl_param.channel_wise_quantize.has_value() && + dnnl_param.channel_wise_quantize.value(); + + const FCInputIndex idx(full_param_); + + CHECK_EQ(in_data.size(), idx.GetTotal()); + + int index = 0; + const int out_index = index++; + const int out_min_index = out_quantized ? index++ : 0; + const int out_max_index = out_quantized ? index++ : 0; + CHECK_EQ(out_data.size(), index); // index is equal to total number of outputs + + float data_min = 0.0f; + float data_max = 0.0f; + float weight_min = 0.0f; + float weight_max = 0.0f; + float bias_min = 0.0f; + float bias_max = 0.0f; + + const float sum_min = idx.sum_min ? in_data[idx.sum_min].data().dptr()[0] : 0.0; + const float sum_max = idx.sum_max ? in_data[idx.sum_max].data().dptr()[0] : 0.0; + NDArray data = in_data[idx.data]; + const NDArray& weight = in_data[idx.weight]; + NDArray output; + + if (dnnl_param.with_sum) { + if (!initialized_) { + // TODO(zhennan): Currently, dnnl fallback mechanism will break inplace option, + // which make check (req[out_index] == kWriteInplace) useless. + auto in_dnnl_mem = static_cast(in_data[idx.sum].GetDNNLData()); + auto out_dnnl_mem = static_cast(out_data[out_index].GetDNNLData()); + if (in_dnnl_mem->get_data_handle() == out_dnnl_mem->get_data_handle()) { + inplace_ = true; + } } - - total_num_inputs_ = base_num_inputs; - total_num_outputs_ = base_num_outputs; - if (dnnl_param.quantized) { - total_num_inputs_ = channel_wise_runtime_ ? (base_num_inputs + 2) : (base_num_inputs * 3); - total_num_outputs_ = - dnnl_param.enable_float_output ? base_num_outputs : (base_num_outputs * 3); + if (inplace_) { + output = in_data[idx.sum]; + } else { + // Not in place: copy in_data[idx.sum] into outputs[out_index]. 
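+      // Note: the oneDNN "sum" post-op accumulates the FC result into the destination
+      // memory, so the sum operand has to be placed in the output buffer beforehand.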
+ auto in_dnnl_mem = static_cast(in_data[idx.sum].GetDNNLData()); + auto out_dnnl_mem = static_cast(out_data[out_index].GetDNNLData()); + if (out_data[out_index].dtype() == mshadow::kInt32) { + auto mem_desc = in_dnnl_mem->get_desc(); + auto this_dtype = get_dnnl_type(mshadow::kInt32); + mem_desc.data.data_type = static_cast(this_dtype); + dnnl_mem_ptr tmp_mem(new dnnl::memory( + mem_desc, CpuEngine::Get()->get_engine(), out_dnnl_mem->get_data_handle())); + DNNLStream::Get()->RegisterMem(tmp_mem); + DNNLStream::Get()->RegisterPrimArgs( + dnnl::reorder(*in_dnnl_mem, *tmp_mem), + {{DNNL_ARG_FROM, *in_dnnl_mem}, {DNNL_ARG_TO, *tmp_mem}}); + output = NDArray(tmp_mem); + } else { + dnnl_mem_ptr tmp_mem(new dnnl::memory(in_dnnl_mem->get_desc(), + CpuEngine::Get()->get_engine(), + out_dnnl_mem->get_data_handle())); + DNNLStream::Get()->RegisterMem(tmp_mem); + DNNLMemoryCopy(*in_dnnl_mem, tmp_mem.get()); + output = NDArray(tmp_mem); + } } + } else { + output = out_data[out_index]; } - CHECK_EQ(in_data.size(), total_num_inputs_); - CHECK_EQ(out_data.size(), total_num_outputs_); - - NDArray data = in_data[fullc::kData]; - const NDArray& weight = in_data[fullc::kWeight]; - const NDArray& output = out_data[fullc::kOut]; if (dnnl_param.quantized) { - if (!channel_wise_runtime_) { - min_weight = in_data[base_num_inputs + quantized_fullc::kWeightMin].data().dptr()[0]; - max_weight = in_data[base_num_inputs + quantized_fullc::kWeightMax].data().dptr()[0]; + if (!channel_wise) { + weight_min = in_data[idx.weight_min].data().dptr()[0]; + weight_max = in_data[idx.weight_max].data().dptr()[0]; if (has_bias) { - min_bias = in_data[base_num_inputs + quantized_fullc::kBiasMin].data().dptr()[0]; - max_bias = in_data[base_num_inputs + quantized_fullc::kBiasMax].data().dptr()[0]; + bias_min = in_data[idx.bias_min].data().dptr()[0]; + bias_max = in_data[idx.bias_max].data().dptr()[0]; } } - min_data = in_data[base_num_inputs + quantized_fullc::kDataMin].data().dptr()[0]; - max_data = in_data[base_num_inputs + quantized_fullc::kDataMax].data().dptr()[0]; + data_min = in_data[idx.data_min].data().dptr()[0]; + data_max = in_data[idx.data_max].data().dptr()[0]; } if (initialized_ && dnnl_param.quantized && dmlc::GetEnv("MXNET_ONEDNN_QFC_DYNAMIC_PARAMS", 0)) { - if (channel_wise_runtime_) { - if (cached_min_data_ != min_data || cached_max_data_ != max_data || + if (channel_wise) { + if (cached_data_min_ != data_min || cached_data_max_ != data_max || + cached_sum_min_ != sum_min || cached_sum_max_ != sum_max || weight_ver_ != weight.version() || - (has_bias && (bias_ver_ != in_data[fullc::kBias].version()))) { + (has_bias && (bias_ver_ != in_data[idx.bias].version()))) { initialized_ = false; } } else { - if (cached_min_data_ != min_data || cached_max_data_ != max_data || - cached_min_weight_ != min_weight || cached_max_weight_ != max_weight || - (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != max_bias))) { + if (cached_data_min_ != data_min || cached_data_max_ != data_max || + cached_sum_min_ != sum_min || cached_sum_max_ != sum_max || + cached_weight_min_ != weight_min || cached_weight_max_ != weight_max || + (has_bias && (cached_bias_min_ != bias_min || cached_bias_max_ != bias_max))) { initialized_ = false; } } @@ -157,17 +200,19 @@ void SgDNNLFCOp::Forward(const OpContext& ctx, if (!initialized_) { const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); const auto engine = CpuEngine::Get()->get_engine(); - cached_min_data_ = min_data; - cached_max_data_ = max_data; - 
cached_min_weight_ = min_weight; - cached_max_weight_ = max_weight; + cached_data_min_ = data_min; + cached_data_max_ = data_max; + cached_weight_min_ = weight_min; + cached_weight_max_ = weight_max; weight_ver_ = weight.version(); cached_weight_ = weight; + cached_sum_min_ = sum_min; + cached_sum_max_ = sum_max; if (has_bias) { - cached_min_bias_ = min_bias; - cached_max_bias_ = max_bias; - bias_ver_ = in_data[fullc::kBias].version(); - cached_bias_ = in_data[fullc::kBias]; + cached_bias_min_ = bias_min; + cached_bias_max_ = bias_max; + bias_ver_ = in_data[idx.bias].version(); + cached_bias_ = in_data[idx.bias]; } else { cached_bias_ = NDArray(); } @@ -210,13 +255,13 @@ void SgDNNLFCOp::Forward(const OpContext& ctx, bool support_channelwise_scale = false; if (dnnl_param.quantized) { CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8); - data_scale_ = GetQuantizeScale(data.dtype(), cached_min_data_, cached_max_data_); + data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, cached_data_max_); bool fuse_requantize = false; // Channelwise scaling is only supported when fusion is enabled (requantize or dequantize). if (dnnl_param.min_calib_range.has_value() && dnnl_param.max_calib_range.has_value()) { - cached_min_output_ = dnnl_param.min_calib_range.value(); - cached_max_output_ = dnnl_param.max_calib_range.value(); + cached_output_min_ = dnnl_param.min_calib_range.value(); + cached_output_max_ = dnnl_param.max_calib_range.value(); support_channelwise_scale = true; fuse_requantize = true; } @@ -227,7 +272,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx, // True True True // True False Error // False True/False False - if (channel_wise_runtime_ && !support_channelwise_scale) { + if (channel_wise && !support_channelwise_scale) { LOG(FATAL) << "Currently, channel-wise quantization requires fuse requantize or dequantize." << " Please make sure the `min_calib_range` and `max_calib_range` are set when only" @@ -236,7 +281,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx, << " or the env var of `MXNET_DISABLE_ONEDNN_QFC_FLOAT_OUTPUT` and " << " `MXNET_DISABLE_ONEDNN_QFC_FUSE_ALL` are not set to true (default is false)"; } - support_channelwise_scale = support_channelwise_scale && channel_wise_runtime_; + support_channelwise_scale = support_channelwise_scale && channel_wise; if (support_channelwise_scale) { MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, { @@ -248,51 +293,56 @@ void SgDNNLFCOp::Forward(const OpContext& ctx, } else { weight_scales_.resize(1); weight_scales_[0] = - GetQuantizeScale(cached_weight_.dtype(), cached_min_weight_, cached_max_weight_); + GetQuantizeScale(cached_weight_.dtype(), cached_weight_min_, cached_weight_max_); if (has_bias) { - float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_min_bias_, cached_max_bias_); - float bias_int32_rescale = data_scale_ * weight_scales_[0] / bias_scale; - // TODO(zhennan): dnnl has bug to handle INT_MAX in bias, so set the maximum value - // of bias to INT_MAX / 2. 
- float bias_max_rescale = - MaxValue() / 2 / MaxAbs(cached_min_bias_, cached_max_bias_) / bias_scale; - if (bias_int32_rescale > bias_max_rescale) { - // avoid overflow on bias - bias_int32_rescale = bias_max_rescale; - float weight_rescale = - bias_int32_rescale * bias_scale / data_scale_ / weight_scales_[0]; - int8_t* weight_ptr = weight.data().dptr(); - size_t weight_size = weight.shape().Size(); + if (cached_bias_.dtype() == mshadow::kInt8) { + float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_bias_min_, cached_bias_max_); + + float bias_int32_rescale = data_scale_ * weight_scales_[0] / bias_scale; + // TODO(zhennan): dnnl has bug to handle INT_MAX in bias, so set + // the maximum value of bias to INT_MAX / 2. + float bias_max_rescale = + MaxValue() / 2 / MaxAbs(cached_bias_min_, cached_bias_max_) / bias_scale; + if (bias_int32_rescale > bias_max_rescale) { + // avoid overflow on bias + bias_int32_rescale = bias_max_rescale; + float weight_rescale = + bias_int32_rescale * bias_scale / data_scale_ / weight_scales_[0]; + int8_t* weight_ptr = weight.data().dptr(); + size_t weight_size = weight.shape().Size(); #pragma omp parallel for num_threads(nthreads) - for (index_t i = 0; i < static_cast(weight_size); ++i) { - weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale); + for (index_t i = 0; i < static_cast(weight_size); ++i) { + weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale); + } + weight_scales_[0] *= weight_rescale; } - weight_scales_[0] *= weight_rescale; - } - NDArray bias = in_data[fullc::kBias]; - cached_bias_ = - NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, mshadow::kInt32); - int8_t* bias_ptr = bias.data().dptr(); - int32_t* quantized_bias_ptr = cached_bias_.data().dptr(); - size_t bias_size = bias.shape().Size(); + NDArray bias = in_data[fullc::kBias]; + cached_bias_ = + NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, mshadow::kInt32); + int8_t* bias_ptr = bias.data().dptr(); + int32_t* quantized_bias_ptr = cached_bias_.data().dptr(); + size_t bias_size = bias.shape().Size(); + #pragma omp parallel for num_threads(nthreads) - for (index_t i = 0; i < static_cast(bias_size); ++i) { - quantized_bias_ptr[i] = std::round(bias_ptr[i] * bias_int32_rescale); + for (index_t i = 0; i < static_cast(bias_size); ++i) { + quantized_bias_ptr[i] = std::round(bias_ptr[i] * bias_int32_rescale); + } } } } size_t num_channel = cached_weight_.shape()[0]; + float out_scale = 1.0f; if (fuse_requantize || dnnl_param.enable_float_output) { float tmp_scale_ = 1.0f; if (fuse_requantize) { if (dnnl_param.with_eltwise) { tmp_scale_ = 1.0 / data_scale_; full_param_.eltwise_param.scale = - GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_); + GetQuantizeScale(output.dtype(), cached_output_min_, cached_output_max_); } else { - tmp_scale_ = GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_) / - data_scale_; + out_scale = GetQuantizeScale(output.dtype(), cached_output_min_, cached_output_max_); + tmp_scale_ = out_scale / data_scale_; } } else { tmp_scale_ = 1.0 / data_scale_; @@ -314,26 +364,33 @@ void SgDNNLFCOp::Forward(const OpContext& ctx, mxnet_op::Kernel::Launch( s, 1, - &cached_min_output_, - &cached_max_output_, - &min_data, - &max_data, - &min_weight, - &max_weight); + &cached_output_min_, + &cached_output_max_, + &data_min, + &data_max, + &weight_min, + &weight_max); } else { mxnet_op::Kernel::Launch( s, 1, - &cached_min_output_, - &cached_max_output_, - &min_data, - &max_data, - &min_weight, - 
&max_weight); + &cached_output_min_, + &cached_output_max_, + &data_min, + &data_max, + &weight_min, + &weight_max); } full_param_.output_scales.resize(0); + out_scale = data_scale_ * weight_scales_[0]; } - } + + if (dnnl_param.with_sum && !dnnl_param.enable_float_output) { + float sum_in_scale = + GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, cached_sum_max_); + full_param_.sum_scale = out_scale / sum_in_scale; + } + } // if (dnnl_param.quantized) fwd_.reset(new DNNLFullyConnectedForward(full_param_, ctx.is_train, @@ -357,10 +414,11 @@ void SgDNNLFCOp::Forward(const OpContext& ctx, weight_scales_, false); } else { - const auto def_weight_mem = weight.GetDNNLData(); + const auto def_weight_mem = static_cast(weight.GetDNNLData()); if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) { - cached_weight_ = NDArray(fwd_->fwd_pd.weights_desc()); - auto cached_weight_mem = cached_weight_.GetDNNLData(); + auto weight_desc = fwd_->fwd_pd.weights_desc(); + cached_weight_ = NDArray(weight_desc); + auto cached_weight_mem = static_cast(cached_weight_.GetDNNLData()); std::unordered_map args( {{DNNL_ARG_FROM, *def_weight_mem}, {DNNL_ARG_TO, *cached_weight_mem}}); DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(*def_weight_mem, *cached_weight_mem), @@ -368,17 +426,32 @@ void SgDNNLFCOp::Forward(const OpContext& ctx, } } - const auto data_mem = data.GetDNNLData(); + const auto data_mem = static_cast(data.GetDNNLData()); cached_data_mem_ = std::make_shared(data_mem->get_desc(), engine); args_[DNNL_ARG_SRC] = *cached_data_mem_; - args_[DNNL_ARG_WEIGHTS] = *cached_weight_.GetDNNLData(); + args_[DNNL_ARG_WEIGHTS] = *static_cast(cached_weight_.GetDNNLData()); if (has_bias) - args_[DNNL_ARG_BIAS] = *cached_bias_.GetDNNLData(); + args_[DNNL_ARG_BIAS] = *static_cast(cached_bias_.GetDNNLData()); args_[DNNL_ARG_DST] = *cached_out_mem_; initialized_ = true; } + if (dnnl_param.with_sum) { + const auto& output_mem = output.GetDNNLData(); + const auto& out_mem_desc = output_mem->get_desc(); + auto dst_mem_desc = fwd_->fwd_pd.dst_desc(); + if (out_mem_desc != dst_mem_desc) { + auto tmp_out_mem = output.GetDNNLDataReorder(dst_mem_desc); + dst_mem_desc.data.data_type = out_mem_desc.data.data_type; + dnnl_mem_ptr new_out_mem(new dnnl::memory( + dst_mem_desc, CpuEngine::Get()->get_engine(), output_mem->get_data_handle())); + DNNLStream::Get()->RegisterMem(new_out_mem); + DNNLMemoryCopy(*tmp_out_mem, new_out_mem.get()); + output = NDArray(new_out_mem); + } + } + if (reorder_data_) { data = data.Reorder2Default(); } @@ -392,10 +465,11 @@ void SgDNNLFCOp::Forward(const OpContext& ctx, DNNLStream::Get()->Submit(); if (dnnl_param.quantized && !dnnl_param.enable_float_output) { - float* min_output_ptr = out_data[quantized_fullc::kOutMin].data().dptr(); - float* max_output_ptr = out_data[quantized_fullc::kOutMax].data().dptr(); - *min_output_ptr = cached_min_output_; - *max_output_ptr = cached_max_output_; + float* output_min_ptr = out_data[out_min_index].data().dptr(); + float* output_max_ptr = out_data[out_max_index].data().dptr(); + + *output_min_ptr = cached_output_min_; + *output_max_ptr = cached_output_max_; } } @@ -450,23 +524,25 @@ static void SgDNNLFCParamParser(nnvm::NodeAttrs* attrs) { static std::vector SgDNNLFCListInputNames(const NodeAttrs& attrs) { auto const& full_param = nnvm::get(attrs.parsed); + auto const& dnnl_param = full_param.dnnl_param; std::vector input_names = DefaultSubgraphOpListInputs(attrs); - if (full_param.dnnl_param.quantized) { - bool channel_wise = false; - if 
(full_param.dnnl_param.channel_wise_quantize.has_value() && - full_param.dnnl_param.channel_wise_quantize) { - channel_wise = true; - } - input_names.emplace_back("min_data"); - input_names.emplace_back("max_data"); + if (dnnl_param.quantized) { + const bool channel_wise = + dnnl_param.channel_wise_quantize.has_value() && dnnl_param.channel_wise_quantize; + input_names.emplace_back("data_min"); + input_names.emplace_back("data_max"); if (!channel_wise) { - input_names.emplace_back("min_weight"); - input_names.emplace_back("max_weight"); + input_names.emplace_back("weight_min"); + input_names.emplace_back("weight_max"); if (!full_param.default_param.no_bias) { - input_names.emplace_back("min_bias"); - input_names.emplace_back("max_bias"); + input_names.emplace_back("bias_min"); + input_names.emplace_back("bias_max"); } } + if (dnnl_param.with_sum && !dnnl_param.enable_float_output) { + input_names.emplace_back("sum_min"); + input_names.emplace_back("sum_max"); + } } return input_names; } @@ -477,19 +553,19 @@ static std::vector SgDNNLFCListOutputNames(const NodeAttrs& attrs) if (full_param.dnnl_param.enable_float_output) return std::vector{"output"}; else - return std::vector{"output", "min_output", "max_output"}; + return std::vector{"output", "output_min", "output_max"}; } else { return std::vector{"output"}; } } template -static inline void FillBaseInputOutputInfo(const FullyConnectedParam& param, +static inline void FillBaseInputOutputInfo(const DNNLFCFullParam& param, std::vector* base_in_attrs, std::vector* base_out_attrs, std::vector* in_attrs, std::vector* out_attrs) { - auto base_num_inputs = param.no_bias ? 2 : 3; + auto base_num_inputs = FCInputIndex(param).GetBase(); base_out_attrs->push_back(out_attrs->at(0)); for (int i = 0; i < base_num_inputs; ++i) { @@ -504,8 +580,7 @@ static bool SgDNNLFCInferShape(const nnvm::NodeAttrs& attrs, if (full_param.dnnl_param.quantized) { mxnet::ShapeVector base_in_shapes; mxnet::ShapeVector base_out_shapes; - FillBaseInputOutputInfo( - full_param.default_param, &base_in_shapes, &base_out_shapes, in_shapes, out_shapes); + FillBaseInputOutputInfo(full_param, &base_in_shapes, &base_out_shapes, in_shapes, out_shapes); bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes); for (size_t i = 0; i < in_shapes->size(); ++i) { @@ -531,26 +606,43 @@ static bool SgDNNLFCInferType(const nnvm::NodeAttrs& attrs, std::vector* out_types) { auto const& full_param = nnvm::get(attrs.parsed); if (full_param.dnnl_param.quantized) { - bool channel_wise = false; - if (full_param.dnnl_param.channel_wise_quantize.has_value() && - full_param.dnnl_param.channel_wise_quantize) { - channel_wise = true; - } - size_t base_num_inputs = full_param.default_param.no_bias ? 
2 : 3;
-    CHECK(in_types->at(0) == mshadow::kInt8 || in_types->at(0) == mshadow::kUint8)
-        << "QuantizedFullyConnected only supports int8/uint8 input, while " << in_types->at(0)
-        << " is given.";
-    for (size_t i = 1; i < in_types->size(); ++i) {
-      if (channel_wise) {
-        TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
-      } else {
-        if (i < base_num_inputs) {
-          TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kInt8);
+    const bool channel_wise = full_param.dnnl_param.channel_wise_quantize.has_value() &&
+                              full_param.dnnl_param.channel_wise_quantize;
+    const FCInputIndex idx(full_param);
+
+    CHECK(in_types->at(idx.data) == mshadow::kInt8 || in_types->at(idx.data) == mshadow::kUint8)
+        << "QuantizedFullyConnected data input only supports int8/uint8, while "
+        << in_types->at(idx.data) << " is given.";
+    if (channel_wise) {
+      TYPE_ASSIGN_CHECK(*in_types, idx.weight, mshadow::kFloat32);
+      if (idx.IsBiasExist()) {
+        TYPE_ASSIGN_CHECK(*in_types, idx.bias, mshadow::kFloat32);
+      }
+    } else {
+      TYPE_ASSIGN_CHECK(*in_types, idx.weight, mshadow::kInt8);
+      if (idx.IsBiasExist()) {
+        if (in_types->at(idx.bias) == -1) {
+          TYPE_ASSIGN_CHECK(*in_types, idx.bias, mshadow::kInt32);
         } else {
-          TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+          CHECK(in_types->at(idx.bias) == mshadow::kInt8 ||
+                in_types->at(idx.bias) == mshadow::kInt32)
+              << "QuantizedFullyConnected bias input only supports int8/int32, while "
+              << in_types->at(idx.bias) << " is given.";
        }
      }
    }
+    if (idx.IsSumExist()) {
+      if (full_param.dnnl_param.enable_float_output) {
+        TYPE_ASSIGN_CHECK(*in_types, idx.sum, mshadow::kFloat32);
+      } else {
+        CHECK(in_types->at(idx.sum) == mshadow::kInt8 || in_types->at(idx.sum) == mshadow::kUint8)
+            << "QuantizedFullyConnected sum input only supports int8/uint8, while "
+            << in_types->at(idx.sum) << " is given.";
+      }
+    }
+    for (size_t i = idx.data_min; i < in_types->size(); ++i) {
+      TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+    }
 
     if (full_param.dnnl_param.enable_float_output) {
       TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
@@ -583,8 +675,7 @@ static bool SgDNNLFCStorageType(const nnvm::NodeAttrs& attrs,
   if (full_param.dnnl_param.quantized) {
     std::vector<int> base_in_attrs;
     std::vector<int> base_out_attrs;
-    FillBaseInputOutputInfo(
-        full_param.default_param, &base_in_attrs, &base_out_attrs, in_attrs, out_attrs);
+    FillBaseInputOutputInfo(full_param, &base_in_attrs, &base_out_attrs, in_attrs, out_attrs);
     bool ret = DefaultSubgraphOpStorageType(
         attrs, dev_mask, dispatch_mode, &base_in_attrs, &base_out_attrs);
 
@@ -606,6 +697,15 @@ static bool SgDNNLFCStorageType(const nnvm::NodeAttrs& attrs,
   }
 }
 
+std::vector<std::pair<int, int>> SgDNNLFCInplaceOption(const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
+  if (param.dnnl_param.with_sum) {
+    return std::vector<std::pair<int, int>>{{FCInputIndex(param).sum, 0}};
+  } else {
+    return std::vector<std::pair<int, int>>();
+  }
+}
+
 static OpStatePtr CreateSgDNNLFCState(const nnvm::NodeAttrs& attrs,
                                       Context ctx,
                                       const mxnet::ShapeVector& in_shapes,
@@ -641,13 +741,16 @@ static bool SgDNNLAvoidFCQuantizeInput(const NodeAttrs& attrs,
                                        const std::string quantize_granularity) {
   auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
   std::unordered_set<size_t> avoid_indexes;
+  FCInputIndex idx(full_param);
   if (quantize_granularity == "channel-wise") {
     avoid_indexes.insert(fullc::kWeight);  // weight
     if (!full_param.default_param.no_bias) {
       avoid_indexes.insert(fullc::kBias);  // bias
     }
   }
-
+  if (idx.IsSumInputFloat()) {
+    avoid_indexes.insert(idx.sum);
+  }
   return avoid_indexes.count(index_to_check);
 }
 
@@ -656,17 +759,7 @@
NNVM_REGISTER_OP(_sg_onednn_fully_connected) .describe(R"code(_sg_onednn_fully_connected)code" ADD_FILELINE) .set_num_inputs([](const NodeAttrs& attrs) { auto const& full_param = nnvm::get(attrs.parsed); - auto num_inputs = full_param.default_param.no_bias ? 2 : 3; - if (full_param.dnnl_param.quantized) { - if (full_param.dnnl_param.channel_wise_quantize.has_value() && - full_param.dnnl_param.channel_wise_quantize) { - return num_inputs + 2; // min_data, max_data - } else { - return num_inputs * 3; - } - } else { - return num_inputs; - } + return FCInputIndex(full_param).GetTotal(); }) .set_num_outputs([](const NodeAttrs& attrs) { auto const& full_param = nnvm::get(attrs.parsed); @@ -691,6 +784,7 @@ NNVM_REGISTER_OP(_sg_onednn_fully_connected) }) .set_attr("FMutateInputs", DefaultSubgraphOpMutableInputs) .set_attr("key_var_num_args", "num_args") + .set_attr("FInplaceOption", SgDNNLFCInplaceOption) .set_attr("FQuantizable", [](const NodeAttrs& attrs) { return QuantizeType::kMust; }) .set_attr("FQuantizedOp", SgDNNLFCQuantizedOp) diff --git a/src/operator/subgraph/dnnl/dnnl_fc_property.h b/src/operator/subgraph/dnnl/dnnl_fc_property.h index 9884dc7168ee..48481cfd20a8 100644 --- a/src/operator/subgraph/dnnl/dnnl_fc_property.h +++ b/src/operator/subgraph/dnnl/dnnl_fc_property.h @@ -193,6 +193,9 @@ class SgDNNLFCProperty : public SubgraphProperty { auto& sub_name = node->op()->name; if (sub_name == "FullyConnected") { node_name << "fully_connected_"; + if (HasAttr("quantize") && GetAttr("quantize")) { + n->attrs.dict["first_quantization_pass"] = "True"; + } } else if (SupportDNNLFCEltwiseFusion(sub_name)) { node_name << "eltwise_"; n->attrs.dict["with_eltwise"] = "True"; diff --git a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h new file mode 100644 index 000000000000..4af89c9298f6 --- /dev/null +++ b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + \file + \brief For fusing FullyConnected operator with element-wise add. + + Element-wise add operator is replaced by DNNL FC "sum" post operator. + It adds FC results to existing values in output. For quantized integer version + this output is scaled to the proper range. 
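+
+  Illustrative pattern handled by this pass (schematic, not exact graph syntax):
+
+    fc_out = FullyConnected(data, weight, bias)
+    out    = elemwise_add(fc_out, residual)   // folded into a oneDNN "sum" post-op on FC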
+*/
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include
+#include
+#include
+#include
+#include
+
+#include "../../tensor/matrix_op-inl.h"
+#include "../common.h"
+#include "dnnl_fc-inl.h"
+#include "dnnl_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+inline bool EndsWith(std::string const& value, std::string const& ending) {
+  if (ending.size() > value.size()) {
+    return false;
+  } else {
+    return std::equal(ending.rbegin(), ending.rend(), value.rbegin());
+  }
+}
+
+class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
+ private:
+  /*! \brief pattern match status */
+  enum SelectStatus {
+    kFail = 0,
+    kStart,
+    kSuccess,
+  };
+
+  bool quantized_;
+  SelectStatus status_ = kFail;
+  std::vector<const BiDirectedNode*> matched_list_;
+
+ public:
+  explicit SgDNNLFCSumFuseSelector(bool quantized) : quantized_(quantized) {}
+
+  bool Select(const BiDirectedNode& seed_node,
+              const std::shared_ptr<NodeAttr>& node_attr) override {
+    const auto n = seed_node.node;
+    if (n->op() == Op::Get("_sg_onednn_fully_connected") && SupportDNNLAttr(node_attr) &&
+        (seed_node.outputs.size() == 1)) {
+      auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
+      if ((!quantized_ && !fc_param.dnnl_param.first_quantization_pass) ||
+          (fc_param.dnnl_param.quantized && !fc_param.dnnl_param.with_eltwise)) {
+        // Start a subgraph when fusing for floats (quantized_ is false for the DNNL backend) or
+        // when FC is already quantized (second pass for DNNL_QUANTIZE) but not already fused
+        // with an elemwise operator.
+        status_ = kStart;
+        matched_list_.clear();
+        matched_list_.push_back(&seed_node);
+        return true;
+      }
+    }
+    return false;
+  }
+
+  bool SelectInput(const BiDirectedNode& cur_node, const BiDirectedNode& input_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode& output_node) override {
+    const auto cur_n    = cur_node.node;
+    const auto output_n = output_node.node;
+    if (status_ == kFail || status_ == kSuccess || output_n->is_variable()) {
+      return false;
+    }
+    // If cur_node isn't the last matched node, then we encountered an internal
+    // branch; pop out the nodes behind cur_node and stop fusion.
+    if (matched_list_.back() != &cur_node) {
+      if (std::find(matched_list_.begin(), matched_list_.end(), &cur_node) != matched_list_.end()) {
+        while (matched_list_.back() != &cur_node) {
+          matched_list_.pop_back();
+        }
+      }
+      status_ = kSuccess;
+      return false;
+    }
+
+    switch (status_) {
+      case kStart:
+        // Find _contrib_quantized_elemwise_add or elemwise_add
+        if (EndsWith(output_n->op()->name, "elemwise_add")) {
+          if (quantized_) {
+            auto const& fc_param = nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
+            if (!fc_param.dnnl_param.enable_float_output) {
+              // For a quantized graph, when float output of FC is not enabled, the elementwise
+              // add must also be quantized (its min and max values have to be already stored
+              // in the elementwise add).
+ CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1); + } + } + matched_list_.push_back(&output_node); + status_ = kSuccess; + return true; + } + default: + status_ = kFail; + return false; + } + } + + std::vector Filter(const std::vector& candidates) override { + if (status_ == kSuccess) { + return candidates; + } else { + return std::vector(0); + } + } + + void Reset() override { + CHECK_GE(matched_list_.size(), 1); + auto new_selector = SgDNNLFCSumFuseSelector(quantized_); + new_selector.Select(*matched_list_[0], nullptr); + *this = new_selector; + } +}; + +class SgDNNLFCSumFuseProperty : public SubgraphProperty { + public: + SgDNNLFCSumFuseProperty() {} + + static SubgraphPropertyPtr Create() { + static const std::string& name = "DNNL fuse FullyConnected with sum"; + auto property = std::make_shared(); + property->SetAttr("property_name", name); + property->SetAttr("inference_only", true); + if (dmlc::GetEnv("MXNET_DISABLE_DNNL_FC_SUM", 0)) { + property->SetAttr("disable", true); + } + return property; + } + + nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym, + const int subgraph_id = 0) const override { + nnvm::ObjectPtr fc_node = nullptr; + nnvm::ObjectPtr ew_add_node = nullptr; + + DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) { + if (node->is_variable()) { + return; + } + auto& sub_name = node->op()->name; + if (sub_name == "_sg_onednn_fully_connected") { + fc_node = node; + } else if (EndsWith(sub_name, "elemwise_add")) { + ew_add_node = node; + } + }); + + CHECK_NOTNULL(fc_node); + if (ew_add_node != nullptr) { + CHECK_NOTNULL(fc_node->attrs.subgraphs[0]); + auto subgraph_output_node = fc_node->attrs.subgraphs[0]->outputs[0].node; + nnvm::Symbol new_sym; + // Create a new elemwise_add node to not alter the original one. + // It is needed in subgraph to properly calculate InferShape. + nnvm::ObjectPtr n = nnvm::Node::Create(); + n->attrs.op = Op::Get("elemwise_add"); + n->attrs.name = ew_add_node->attrs.name; + + if (ew_add_node->inputs[0].node == fc_node) { + n->inputs.emplace_back(subgraph_output_node); + n->inputs.emplace_back(ew_add_node->inputs[1]); + } else { + n->inputs.emplace_back(ew_add_node->inputs[0]); + n->inputs.emplace_back(subgraph_output_node); + } + new_sym.outputs.emplace_back(n); + fc_node->attrs.subgraphs.clear(); + fc_node->attrs.subgraphs.emplace_back(std::make_shared(new_sym)); + fc_node->attrs.dict["with_sum"] = "True"; + fc_node->attrs.dict.erase("first_quantization_pass"); // Removed as not needed any longer + fc_node->op()->attr_parser(&(fc_node->attrs)); + } + return fc_node; + } + + SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override { + bool quantized = HasAttr("quantize") ? 
GetAttr<bool>("quantize") : false;
+    auto selector  = std::make_shared<SgDNNLFCSumFuseSelector>(quantized);
+    return selector;
+  }
+
+  void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
+                              std::vector<nnvm::NodeEntry*>* output_entries) const override {
+    // Connect all external output entries to output[0]
+    for (size_t i = 0; i < output_entries->size(); ++i) {
+      auto entry_ptr = output_entries->at(i);
+      *entry_ptr     = nnvm::NodeEntry{n, entry_ptr->index, 0};
+    }
+  }
+
+  void ConnectSubgraphInputs(const nnvm::ObjectPtr n,
+                             std::vector<nnvm::NodeEntry*>* input_entries,
+                             std::vector<nnvm::NodeEntry>* orig_input_entries) const override {
+    auto sym             = n->attrs.subgraphs[0];
+    auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
+    std::unordered_set<nnvm::Node*> node_sets;
+    DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr& node) {
+      if (node->is_variable()) {
+        return;
+      }
+      node_sets.insert(node.get());
+      if (EndsWith(node->op()->name, "elemwise_add")) {
+        const size_t base_inputs = fc_param.default_param.no_bias ? 3 : 4;
+        // Make sure the FC output is the left operand of the add operator and the tensor
+        // (sum_tensor) to which the FC output is added is the last base input. If not:
+        //  - swap the inputs of the add operator,
+        //  - rotate the external input entries accordingly.
+        if (node_sets.count(node->inputs[1].node.get())) {
+          // Example of input_entries reordering for a channel-wise quantized graph:
+          // sum_tensor.data --> fc.data
+          // fc.data         --> fc.weight0
+          // fc.weight0      --> fc.bias0
+          // fc.bias0        --> sum_tensor.data
+          // fc_out.min      --> fc_out.min
+          // fc_out.max      --> fc_out.max
+          // sum_tensor.min  --> sum_tensor.min
+          // sum_tensor.max  --> sum_tensor.max
+          std::swap(node->inputs[0], node->inputs[1]);
+          std::rotate(input_entries->begin(),
+                      input_entries->begin() + 1,
+                      input_entries->begin() + base_inputs);
+          std::rotate(orig_input_entries->begin(),
+                      orig_input_entries->begin() + 1,
+                      orig_input_entries->begin() + base_inputs);
+        } else {
+          // Example of input_entries reordering for a channel-wise quantized graph:
+          // fc.data         --> fc.data
+          // fc.weight0      --> fc.weight0
+          // fc.bias0        --> fc.bias0
+          // fc_out.min      --> sum_tensor.data
+          // fc_out.max      --> fc_out.min
+          // sum_tensor.data --> fc_out.max
+          // sum_tensor.min  --> sum_tensor.min
+          // sum_tensor.max  --> sum_tensor.max
+          const int not_rotated_end =
+              (fc_param.dnnl_param.quantized && !fc_param.dnnl_param.enable_float_output) ?
2 : 0; + + std::rotate(input_entries->begin() + base_inputs - 1, + input_entries->end() - 1 - not_rotated_end, + input_entries->end() - not_rotated_end); + std::rotate(orig_input_entries->begin() + base_inputs - 1, + orig_input_entries->end() - 1 - not_rotated_end, + orig_input_entries->end() - not_rotated_end); + } + } + }); + n->inputs = *orig_input_entries; + } +}; + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_ONEDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_ diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h index cddf4b447810..886b06878bdd 100644 --- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h +++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h @@ -44,6 +44,9 @@ const std::set support_req_fusion_op = {"_contrib_quantized_elemwis "_sg_onednn_selfatt_qk", "_sg_onednn_selfatt_valatt", "_sg_onednn_batch_dot"}; + +const std::set no_enable_float_output = {Op::Get("_contrib_quantized_elemwise_add"), + Op::Get("_sg_onednn_conv")}; } // namespace class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 { @@ -109,7 +112,8 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 { if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { matched_list.emplace_back(&new_node); status = SelectStatus::kRequantize; - if (raw_node->op() == Op::Get("_sg_onednn_conv")) { + if ((raw_node->op() == Op::Get("_sg_onednn_conv")) || + (raw_node->op() == Op::Get("_contrib_quantized_elemwise_add"))) { status = SelectStatus::kSuccess; } return true; @@ -209,7 +213,7 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty { // When only fused quantized operator and requantize, set min/max_cablib_range, // When fused quantized operator + requantize + dequantize, set dequantize flag to true. 
- if (dequantize_node != nullptr) { + if ((dequantize_node != nullptr) && (no_enable_float_output.count(fuse_node->op()) == 0)) { fuse_node->attrs.dict["enable_float_output"] = "True"; } else { fuse_node->attrs.dict["min_calib_range"] = diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc index 8f8fc446808d..b3b23e1a911d 100644 --- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc +++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc @@ -28,6 +28,7 @@ #include "dnnl_post_quantize_property.h" #include "dnnl_transformer_qk_property.h" #include "dnnl_transformer_valatt_property.h" +#include "dnnl_fc_sum_fuse.h" namespace mxnet { namespace op { @@ -43,6 +44,7 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBNReLUProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerQKProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerValAttProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBatchDotProperty); +MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLFCSumFuseProperty); MXNET_REGISTER_SUBGRAPH_BACKEND(ONEDNN_QUANTIZE).set_attr("context", Context::CPU()); @@ -55,6 +57,8 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLBatchDotProperty) .set_attr("quantize", true); MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLPostQuantizeProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLPostQuantizeAlignScaleProperty); +MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCSumFuseProperty) + .set_attr("quantize", true); } // namespace op } // namespace mxnet diff --git a/tests/python/dnnl/subgraphs/subgraph_common.py b/tests/python/dnnl/subgraphs/subgraph_common.py index 3ed526ca56d5..be2adb9e2f03 100644 --- a/tests/python/dnnl/subgraphs/subgraph_common.py +++ b/tests/python/dnnl/subgraphs/subgraph_common.py @@ -111,8 +111,27 @@ def check_qsym_scale_align(qsym): assert max_calib_range == v['max_calib_range'] -def check_quantize(net_original, data_shape, out_type, name='conv', - check_calibration=True, check_scale_align=False): +def check_fusion_parameter(sym, attrs_dict): + for name, attrs in attrs_dict.items(): + if name in config: + op_name = config[name][OP_NAME] + else: + op_name = name + assert ''.join(sym.get_internals().list_outputs()).find(op_name) != -1 + if len(attrs): + found = False + for k, v in sym.attr_dict().items(): + if k.find('_quantize') != -1: + continue + if k.find(op_name) != -1: + found = True + for attr_name, attr_value in attrs.items(): + assert v[attr_name].lower() == attr_value.lower() + assert found + +def check_quantize(net_original, data_shapes, out_type, name='conv', + check_calibration=True, check_scale_align=False, quantize_mode='full', + attrs_dict={}): quantize_granularity_list = ['tensor-wise'] if name == 'fc': quantize_granularity_list += ['channel-wise'] @@ -122,92 +141,108 @@ def check_quantize(net_original, data_shape, out_type, name='conv', net_original.initialize(init=mx.init.Normal(0.5), force_reinit=True) min_value = -1 if out_type != 'uint8' else 0 - data = mx.np.random.uniform(min_value, 1.0, size=data_shape, dtype='float32', ctx=mx.current_device()) - - outputs = net_original(data) + one_shape = isinstance(data_shapes, tuple) + if one_shape: + # replace one shape with list of shapes with one element inside to follow later the same schema + data_shapes=[data_shapes] + data = [] + for shape in data_shapes: + data.append(mx.np.random.uniform(min_value, 1.0, size=shape, dtype='float32', device=mx.cpu())) + + outputs = 
net_original(*data) for output in outputs: output.wait_to_read() ref_out = outputs - calib_data = mx.gluon.data.DataLoader(data, batch_size=1) + dataArray= mx.gluon.data.ArrayDataset(*data) + + calib_data = mx.gluon.data.DataLoader(dataArray, batch_size=1) for quantize_granularity in quantize_granularity_list: qnet = quantization.quantize_net(net_original, - ctx=mx.current_device(), + device=mx.cpu(), exclude_layers=None, exclude_operators=None, quantized_dtype=out_type, calib_mode='naive', calib_data=calib_data, num_calib_batches=1, - quantize_mode='full', + quantize_mode=quantize_mode, quantize_granularity=quantize_granularity) qsym, _ = qnet.export(None) + check_fusion_parameter(qsym, attrs_dict) if check_calibration: check_qsym_calibrated(qsym, out_type, name=name) if check_scale_align: check_qsym_scale_align(qsym) - quantized_out = qnet(data) + quantized_out = qnet(*data) for i in range(len(ref_out)): min_range = mx.np.min(ref_out[i]).item() max_range = mx.np.max(ref_out[i]).item() atol = 0.1 * max(abs(min_range), abs(max_range)) - assert_almost_equal_with_err(quantized_out.asnumpy(), ref_out.asnumpy(), rtol=0.1, atol=atol, etol=0.2) + assert_almost_equal_with_err(quantized_out.asnumpy(), ref_out.asnumpy(), + rtol=0.1, atol=atol, etol=0.2) -def check_fusion(net_original, data_shape, attrs_dict, check_fp32_fusion=True, check_quantization=True, - out_types=['uint8', 'int8', 'auto'], dedup_subgraph=True): +def check_fusion(net_original, data_shapes, attrs_dict, check_fp32_fusion=True, + check_quantization=True, out_types=['uint8', 'int8', 'auto'], dedup_subgraph=True, + quantize_mode='full'): net_original.initialize() net_original.hybridize(static_alloc=False, static_shape=False) - data = mx.np.random.uniform(size=data_shape, dtype='float32', ctx=mx.current_device()) - net_original(data) + one_shape = isinstance(data_shapes, tuple) + data_min = -1.0 + data_max = 1.0 + + if one_shape: + # replace one shape with list of shapes with one element to follow later the same schema + data_shapes=[data_shapes] + data = [] + for shape in data_shapes: + data.append(mx.np.random.uniform(size=shape, dtype='float32', device=mx.cpu(), + low=data_min, high=data_max)) + net_original(*data) net_fusion = copy.copy(net_original) sym, params = net_original.export(None) if check_fp32_fusion: - data_min = -1.0 - data_max = 1.0 if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1: check_quantization = False data_min = 0 sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=dedup_subgraph, skip_infer=True) - for name, attrs in attrs_dict.items(): - if name in config: - op_name = config[name][OP_NAME] - else: - op_name = name - assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1 - if len(attrs): - found = False - for k, v in sym_sg.attr_dict().items(): - if k.find(op_name) != -1: - found = True - for attr_name, attr_value in attrs.items(): - assert v[attr_name].lower() == attr_value.lower() - assert found - - data = mx.np.random.uniform(size=data_shape, low=data_min, high=data_max) - out_unfused = net_original(data) - - net_fusion.optimize_for(data, backend=SG_PASS_NAME) - out_fused = net_fusion(data) + check_fusion_parameter(sym_sg, attrs_dict) + if data_min == 0 and mx.npx.is_np_default_dtype(): + # regenerate inputs if they have different range or data type + data = [] + for shape in data_shapes: + data.append(mx.np.random.uniform(size=shape, device=mx.cpu(), low=data_min, high=data_max)) + out_unfused = net_original(*data) + + net_fusion.optimize_for(*data, 
backend=SG_PASS_NAME)
+    out_fused = net_fusion(*data)
 
     assert_almost_equal(out_unfused.asnumpy(), out_fused.asnumpy(), rtol=1e-3, atol=1e-1)
 
   if check_quantization:
     # fp32 to int8
     for out_type in out_types:
-      check_quantize(net_original, data_shape, out_type, name=name)
+      check_quantize(net_original, data_shapes, out_type, name=list(attrs_dict.keys())[0],
+                     quantize_mode=quantize_mode, attrs_dict=attrs_dict)
 
 
 def check_neg_fusion(net_original, attrs_name=None, excluded_attrs=None,
-                     data_shapes=(4,4,10,10), name='conv'):
+                     data_shapes=[(4,4,10,10)], name='conv'):
   op_name = config[name][OP_NAME]
+  one_shape = isinstance(data_shapes, tuple)
+  if one_shape:
+    # wrap a single shape in a list so the code below can handle both cases the same way
+    data_shapes = [data_shapes]
+  data = []
+  for shape in data_shapes:
+    data.append(mx.np.random.uniform(size=shape))
 
-  data_nd = mx.np.random.uniform(size=data_shapes)
   net_original.initialize()
   net_original.hybridize()
-  net_original(data_nd)
+  net_original(*data)
   sym, _ = net_original.export(None)
 
   sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=True, skip_infer=True)
@@ -218,4 +253,41 @@ def check_neg_fusion(net_original, attrs_name=None, excluded_attrs=None,
       for attr in attrs_name:
         assert v[attr] == 'true'
       for exc_attr in excluded_attrs:
-        assert exc_attr not in v.keys()
+        assert exc_attr not in v.keys(), exc_attr + " attribute shouldn't exist"
+
+
+
+def check_neg_fusion_quantized(net_original, attrs_name=None, excluded_attrs=None,
+                               data_shapes=[(4,4,10,10)], name='conv'):
+  op_name = config[name][OP_NAME]
+  net_original.initialize(init=mx.init.Normal(0.5), force_reinit=True)
+  one_shape = isinstance(data_shapes, tuple)
+  if one_shape:
+    # wrap a single shape in a list so the code below can handle both cases the same way
+    data_shapes = [data_shapes]
+  data = []
+  for shape in data_shapes:
+    data.append(mx.np.random.uniform(size=shape, dtype='float32', device=mx.cpu()))
+
+  dataArray = mx.gluon.data.ArrayDataset(*data)
+  calib_data = mx.gluon.data.DataLoader(dataArray, batch_size=1)
+
+  qnet = quantization.quantize_net(net_original,
+                                   device=mx.cpu(),
+                                   exclude_layers=None,
+                                   exclude_operators=None,
+                                   quantized_dtype='int8',
+                                   calib_mode='naive',
+                                   calib_data=calib_data,
+                                   num_calib_batches=1,
+                                   quantize_mode='full',
+                                   quantize_granularity='tensor-wise')
+  qsym, _ = qnet.export(None)
+  attrs_dict = qsym.attr_dict()
+  for k, v in attrs_dict.items():
+    if k.find(op_name) != -1:
+      for attr in attrs_name:
+        assert v[attr] == 'true'
+      for exc_attr in excluded_attrs:
+        assert exc_attr not in v.keys(), exc_attr + " attribute shouldn't exist"
+
diff --git a/tests/python/dnnl/subgraphs/test_conv_subgraph.py b/tests/python/dnnl/subgraphs/test_conv_subgraph.py
index 6b6169bbed9d..e7dac8f8be59 100644
--- a/tests/python/dnnl/subgraphs/test_conv_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_conv_subgraph.py
@@ -91,7 +91,7 @@ def forward(self, x):
 
   attr = {'conv': {'with_sum': 'true'}}
   net = ConvAdd(use_bias=use_bias)
-  check_fusion(net, data_shape, attr)
+  check_fusion(net, data_shape, attr, check_quantization=False)
 
 
 @mx.util.use_np
@@ -112,14 +112,14 @@ def forward(self, x):
 
   attr = {'conv': {'with_sum': 'true'}}
   net = ConvAdd(use_bias=True)
-  check_fusion(net, data_shape, attr)
+  check_fusion(net, data_shape, attr, check_quantization=False)
 
 
 @mx.util.use_np
 @pytest.mark.parametrize('data_shape', DATA_SHAPE)
 @pytest.mark.parametrize('alg,quantize', [
   ("relu", False), #TODO(bgawrych): investigate
-  ("sigmoid", True),
+  ("sigmoid", False),
("log_sigmoid", False), ("mish", False), ("tanh", False), #TODO(bgawrych): investigate @@ -162,11 +162,11 @@ def forward(self, x): @pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize('alg,quantize', [ ("relu", True), - ("sigmoid", True), - ("log_sigmoid", True), - ("mish", True), - ("tanh", True), - ("softrelu", True), + ("sigmoid", False), + ("log_sigmoid", False), + ("mish", False), + ("tanh", False), + ("softrelu", False), ("relu6", True), ("leakyrelu", True), ("gelu", True) @@ -200,14 +200,14 @@ def forward(self, x): @mx.util.use_np @pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize('alg,quantize', [ - ("relu", True), - ("sigmoid", True), - ("log_sigmoid", True), - ("mish", True), - ("tanh", True), + ("relu", False), + ("sigmoid", False), + ("log_sigmoid", False), + ("mish", False), + ("tanh", False), #("softrelu", True), #TODO(bgawrych): failing fusion check - difference in random single element - ("relu6", True), - ("leakyrelu", True), + ("relu6", False), + ("leakyrelu", False), ("gelu", False) #TODO: for True we get assert instead of not fusing pattern ]) @pytest.mark.parametrize('use_bias', [True, False]) @@ -321,11 +321,11 @@ def infer_shape(self, x, *args): @pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize('alg,quantize', [ ("relu", True), - ("sigmoid", True), - ("log_sigmoid", True), - ("mish", True), - ("tanh", True), - ("softrelu", True), + ("sigmoid", False), + ("log_sigmoid", False), + ("mish", False), + ("tanh", False), + ("softrelu", False), ("relu6", True), ("leakyrelu", True), ("gelu", True) diff --git a/tests/python/dnnl/subgraphs/test_fc_subgraph.py b/tests/python/dnnl/subgraphs/test_fc_subgraph.py index 223a55df0f96..c63bb9ae9ce5 100644 --- a/tests/python/dnnl/subgraphs/test_fc_subgraph.py +++ b/tests/python/dnnl/subgraphs/test_fc_subgraph.py @@ -17,7 +17,7 @@ import mxnet as mx import pytest -from subgraph_common import check_fusion, check_neg_fusion +from subgraph_common import check_fusion, check_neg_fusion, check_neg_fusion_quantized from subgraph_common import CustomNormalInit, DATA_SHAPE, TailNegBlock from mxnet.contrib import quantization from mxnet.gluon import nn @@ -89,9 +89,11 @@ def forward(self, x): out = mx.np.clip(fc_out, 0, 1.0) return out + not_quant_fuze = ['sigmoid', 'log_sigmoid', 'softrelu', 'tanh', 'mish', 'square', 'square_root', + 'exp'] attrs = {'fc': {'with_eltwise': 'true'}} net = FCEltwise(use_bias, flatten, alg) - check_fusion(net, data_shape, attrs, check_quantization=flatten) + check_fusion(net, data_shape, attrs, check_quantization=flatten and not alg in not_quant_fuze) @mx.util.use_np @@ -148,7 +150,7 @@ def forward(self, x): conv1 = mx.npx.fully_connected(x, num_hidden=64, weight=self.weight.data(x.device), no_bias=False, bias=self.bias.data(x.device)) return conv1 - + def infer_shape(self, x, *args): self.weight.shape = (64, x.shape[x.ndim-1]) self.bias.shape = (64,) @@ -232,3 +234,139 @@ def forward(self, x): 'sg_onednn_fully_connected_eltwise_1' : {'with_eltwise': 'true'}} net = FCIdentityEltwise(identity_node) check_fusion(net, data_shape, attrs, check_quantization=False) + + +def function_fc_add(data_shape, add_op, quantize_mode, fc_out_add, flatten, relu, out_type): + class FCWithSumExample(nn.HybridBlock): + def __init__(self, num_hidden, add_op, fc_out_add, **kwargs): + super(FCWithSumExample, self).__init__(**kwargs) + self.fca = nn.Dense(units=num_hidden, flatten=flatten) + self.elemwise_add = (add_op == 'elemwise_add') + self.fc_out_as_rhs = 
(fc_out_add == 'rhs') + self.relu = (relu == 'leaky_relu') + + def forward(self, data1a, data2): + fc_out = self.fca(data1a) + if self.relu: + fc_out = mx.npx.leaky_relu(fc_out, act_type='gelu') + if self.fc_out_as_rhs: + if self.elemwise_add: + sum1 = mx.nd.elemwise_add(data2.as_nd_ndarray(), fc_out.as_nd_ndarray()).as_np_ndarray() + else: + sum1 = data2 + fc_out + else: + if self.elemwise_add: + sum1 = mx.nd.elemwise_add(fc_out.as_nd_ndarray(), data2.as_nd_ndarray()).as_np_ndarray() + else: + sum1 = fc_out + data2 + return sum1 + + attrs = {'fc': {'with_sum': 'true'}} + if quantize_mode is not None: + attrs['fc']['quantized'] = 'true' + if quantize_mode == 'smart': + attrs['fc']['enable_float_output'] = 'true' + num_hidden=10 + net = FCWithSumExample(num_hidden, add_op, fc_out_add) + if flatten: + data_shapes = [data_shape, (data_shape[0], num_hidden)] + else: + data_shapes = [data_shape, (*data_shape[0:-1], num_hidden)] + check_fusion(net, data_shapes, attrs, + out_types=[out_type], + check_fp32_fusion=(quantize_mode is None), + check_quantization=(quantize_mode is not None) and flatten, + quantize_mode=quantize_mode) + +@mx.util.use_np +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('relu', ['noleaky_re', 'leaky_relu']) +@pytest.mark.parametrize('flatten', ['flat', 'nofl']) +@pytest.mark.parametrize('fc_out_add', ['lhs', 'rhs']) +@pytest.mark.parametrize('add_op', ['elemwise_add']) +def test_fc_add(data_shape, add_op, fc_out_add, flatten, relu): + function_fc_add(data_shape, add_op, None, fc_out_add, flatten=='flat', relu, None) + +@mx.util.use_np +@pytest.mark.seed(1234) # Seed set because the test is not robust enough to operate on random data +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('quantize_mode', ['full', 'smart']) +@pytest.mark.parametrize('out_type', ['int8', 'auto']) +@pytest.mark.parametrize('fc_out_add', ['lhs', 'rhs']) +@pytest.mark.parametrize('add_op', ['elemwise_add']) +def test_fc_add_quantized(data_shape, add_op, quantize_mode, fc_out_add, out_type): + function_fc_add(data_shape, add_op, quantize_mode, fc_out_add, True, 'noleaky_re', out_type) + + +class NegFCAdd(nn.HybridBlock): + # + # data --------------------------> 'add_op' ------------> + # / \ + # sg_oned_dnn_fully_connected ----> npi_add --> + # \ / + # npi_multiply_scalar --> + def __init__(self, num_hidden, add_op, fc_out_add, scaled_fc_out, flatten, **kwargs): + super(NegFCAdd, self).__init__(**kwargs) + self.fca = nn.Dense(units=num_hidden, flatten=flatten) + self.elemwise_add = (add_op == 'elemwise_add') + self.fc_out_as_rhs = (fc_out_add == 'rhs') + self.scaled_fc_out_as_rhs = (scaled_fc_out == 's_rhs') + + def forward(self, data1a, data2): + fc_out = self.fca(data1a) + scaled_fc_out = fc_out * 200.0 + if self.fc_out_as_rhs: + if self.elemwise_add: + sum1 = mx.nd.elemwise_add(data2.as_nd_ndarray(), fc_out.as_nd_ndarray()).as_np_ndarray() + else: + sum1 = data2 + fc_out + else: + if self.elemwise_add: + sum1 = mx.nd.elemwise_add(fc_out.as_nd_ndarray(), data2.as_nd_ndarray()).as_np_ndarray() + else: + sum1 = fc_out + data2 + if self.scaled_fc_out_as_rhs: + sum2 = sum1 + scaled_fc_out + else: + sum2 = scaled_fc_out + sum1 + return sum2 + +@mx.util.use_np +@pytest.mark.parametrize('add_op', ['elemwise_add']) +@pytest.mark.parametrize('data_shape', [DATA_SHAPE[0]]) +@pytest.mark.parametrize('flatten', ['flat', 'nofl']) +@pytest.mark.parametrize('fc_out_add', ['lhs', 'rhs']) +@pytest.mark.parametrize('scaled_fc_out', ['s_lhs', 's_rhs']) +def 
test_neg_fc_add(data_shape, add_op, flatten, fc_out_add, scaled_fc_out):
+  '''
+  Test that a FullyConnected operator whose output is not used exclusively by a single
+  'add_op' input is not fused. See NegFCAdd for the graph used in this example.
+  '''
+  flatten = (flatten == 'flat')
+  num_hidden = 10
+  net = NegFCAdd(num_hidden, add_op, fc_out_add, scaled_fc_out, flatten)
+  if flatten:
+    data_shapes = [data_shape, (data_shape[0], num_hidden)]
+  else:
+    data_shapes = [data_shape, (*data_shape[0:-1], num_hidden)]
+  attrs = []
+  excluded_attrs = ['with_sum']
+  check_neg_fusion(net, attrs, excluded_attrs, data_shapes, name='fc')
+
+@mx.util.use_np
+@pytest.mark.parametrize('add_op', ['elemwise_add'])
+@pytest.mark.parametrize('data_shape', [DATA_SHAPE[1]])
+@pytest.mark.parametrize('fc_out_add', ['lhs', 'rhs'])
+@pytest.mark.parametrize('scaled_fc_out', ['s_lhs', 's_rhs'])
+def test_neg_fc_add_quantized(data_shape, add_op, fc_out_add, scaled_fc_out):
+  '''
+  Test that a FullyConnected operator whose output is not used exclusively by a single
+  'add_op' input is not fused in a quantized model. See NegFCAdd for the graph used in
+  this example.
+  '''
+  num_hidden = 10
+  net = NegFCAdd(num_hidden, add_op, fc_out_add, scaled_fc_out, True)
+  data_shapes = [data_shape, (data_shape[0], num_hidden)]
+  attrs = []
+  excluded_attrs = ['with_sum']
+  check_neg_fusion_quantized(net, attrs, excluded_attrs, data_shapes, name='fc')