diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h index d40abaf1fd66..f6149ea30215 100644 --- a/src/operator/nn/convolution-inl.h +++ b/src/operator/nn/convolution-inl.h @@ -53,7 +53,7 @@ enum ConvolutionOpInputs {kData, kWeight, kBias}; enum ConvolutionOpOutputs {kOut}; enum ConvolutionOpResource {kTempSpace}; enum ConvolutionOpCudnnTune {kOff, kLimited, kFastest}; -} +} // namespace conv struct ConvolutionParam : public dmlc::Parameter { TShape kernel; @@ -129,6 +129,10 @@ void ConvolutionParamParser(nnvm::NodeAttrs* attrs); typedef ParamOpSign ConvSignature; +static inline size_t GetInShapeSize(const ConvolutionParam ¶m_) { + return 2 + (param_.no_bias ? 0 : 1); +} + } // namespace op } // namespace mxnet @@ -176,8 +180,7 @@ class ConvolutionOp { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(req[conv::kOut], kWriteTo); - size_t expected = param_.no_bias ? 2 : 3; - CHECK_EQ(in_data.size(), expected); + CHECK_EQ(in_data.size(), GetInShapeSize(param_)); CHECK_EQ(out_data.size(), 1U); CHECK_EQ(req[conv::kOut], kWriteTo); LayerSetUp(in_data[conv::kData].shape_, out_data[conv::kOut].shape_); diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index ef70ccd6ec1e..b53902ff7c4e 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -28,6 +28,7 @@ #include "../elemwise_op_common.h" #include "./mkldnn/mkldnn_ops-inl.h" #include "./mkldnn/mkldnn_base-inl.h" +#include "./mkldnn/mkldnn_convolution-inl.h" #if MXNET_USE_NNPACK == 1 #include "../nnpack/nnpack_pooling-inl.h" #endif // MXNET_USE_NNPACK @@ -41,11 +42,19 @@ static inline index_t AddPad(index_t dsize, index_t pad) { } static inline std::vector ListArguments(const ConvolutionParam& param_) { - if (!param_.no_bias) { + if (!param_.no_bias) return {"data", "weight", "bias"}; - } else { + else return {"data", "weight"}; +} + +static inline std::string PrintArguments(const ConvolutionParam& param_) { + auto args = ListArguments(param_); + std::string str = "["; + for (const auto &arg : args) { + str += arg + ", "; } + return str.substr(0, str.size() - 2) + "]"; } #if MXNET_USE_MKLDNN == 1 @@ -85,11 +94,7 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, std::vector *out_shape) { using namespace mshadow; const ConvolutionParam& param_ = nnvm::get(attrs.parsed); - if (!param_.no_bias) { - CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]"; - } else { - CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]"; - } + CHECK_EQ(in_shape->size(), GetInShapeSize(param_)) << "Input:" << PrintArguments(param_); // CHECK_EQ(out_shape->size(), 1) << "Output: [output]"; out_shape->resize(1, TShape()); const TShape &dshp = (*in_shape)[conv::kData]; @@ -294,7 +299,7 @@ inline static bool ConvStorageType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { const ConvolutionParam& param = nnvm::get(attrs.parsed); - uint32_t in_expected = param.no_bias ? 2 : 3; + uint32_t in_expected = GetInShapeSize(param); CHECK_EQ(in_attrs->size(), in_expected); CHECK_EQ(out_attrs->size(), 1); @@ -470,17 +475,14 @@ There are other options to tune the performance. )code" ADD_FILELINE) .set_num_inputs([](const NodeAttrs& attrs) { const ConvolutionParam& params = nnvm::get(attrs.parsed); - return params.no_bias ? 
2 : 3; + return GetInShapeSize(params); }) .set_num_outputs(1) .set_attr_parser(ConvolutionParamParser) .set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed); - if (params.no_bias) - return std::vector<std::string>{"data", "weight"}; - else - return std::vector<std::string>{"data", "weight", "bias"}; + return ListArguments(params); }) .set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) { diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index 23f2fe694633..c61d7ed05397 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -35,19 +35,48 @@ namespace mxnet { namespace op { +struct ConvFusionParam : public dmlc::Parameter<ConvFusionParam> { + // When adding more members into this class, please double check GetHash() + // won't overflow. + bool with_bn; + bool with_relu; + bool with_sum; + bool with_postsum_relu; + DMLC_DECLARE_PARAMETER(ConvFusionParam) { + DMLC_DECLARE_FIELD(with_bn).set_default(false) + .describe("Add post batchnorm."); + DMLC_DECLARE_FIELD(with_relu).set_default(false) + .describe("Add post relu"); + DMLC_DECLARE_FIELD(with_sum).set_default(false) + .describe("Add post sum"); + DMLC_DECLARE_FIELD(with_postsum_relu).set_default(false) + .describe("Add post relu after sum"); + } + int GetHash() const { + int hash = 0; + hash = hash * 2 + (this->with_bn ? 1 : 0); + hash = hash * 2 + (this->with_relu ? 1 : 0); + hash = hash * 2 + (this->with_sum ? 1 : 0); + hash = hash * 2 + (this->with_postsum_relu ? 1 : 0); + return hash; + } +}; + mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( - const ConvolutionParam& param, const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output); + const ConvolutionParam &param, const ConvFusionParam &fusion_param, + const bool is_train, const NDArray &data, const NDArray &weights, + const NDArray *bias, const NDArray &output); class MKLDNNConvForward { public: mkldnn::convolution_forward::primitive_desc fwd_pd; - MKLDNNConvForward(const ConvolutionParam& param, const bool is_train, + MKLDNNConvForward(const ConvolutionParam &param, + const ConvFusionParam &fusion_param, const bool is_train, const NDArray &data, const NDArray &weights, - const NDArray *bias, const NDArray &output): fwd_pd( - GetConvFwdImpl(param, is_train, data, weights, bias, output)) { - } + const NDArray *bias, const NDArray &output) + : fwd_pd(GetConvFwdImpl(param, fusion_param, is_train, data, weights, + bias, output)) {} void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, const mkldnn::memory *bias, const mkldnn::memory &output); diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index cf04ea8da3d7..83cde062692f 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -34,6 +34,8 @@ namespace mxnet { namespace op { +DMLC_REGISTER_PARAMETER(ConvFusionParam); + bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) { if (params.kernel.ndim() != 2) return false; @@ -41,8 +43,9 @@ bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) { } mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( - const ConvolutionParam& param, const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output) { + const ConvolutionParam &param, const ConvFusionParam &fusion_param, + const 
bool is_train, const NDArray &data, const NDArray &weights, + const NDArray *bias, const NDArray &output) { auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; auto data_md = GetMemDesc(data); auto weight_md = GetWeightDesc(weights, param.num_group); @@ -57,16 +60,36 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( mkldnn::memory::dims padding{0, 0}; padding[0] = param.pad[0]; padding[1] = param.pad[1]; + mkldnn::primitive_attr attr; + mkldnn::post_ops ops; + if (fusion_param.with_relu) { + float scale = 1.0f; // for fp32, scale is 1. + float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. + float beta = 1.0f; // ignored for mkldnn_eltwise_relu. + ops.append_eltwise(scale, eltwise_relu, alpha, beta); + + } + if (fusion_param.with_sum) { + float scale = 1.0f; + ops.append_sum(scale); + } + if (fusion_param.with_postsum_relu) { + float scale = 1.0f; // for fp32, scale is 1. + float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. + float beta = 1.0f; // ignored for mkldnn_eltwise_relu. + ops.append_eltwise(scale, eltwise_relu, alpha, beta); + } + attr.set_post_ops(ops); if (param.dilate.ndim() == 0 && bias == nullptr) { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); - return mkldnn::convolution_forward::primitive_desc(desc, engine); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); } else if (param.dilate.ndim() == 0) { auto bias_md = GetMemDesc(*bias); mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, bias_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); - return mkldnn::convolution_forward::primitive_desc(desc, engine); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); } else { mkldnn::memory::dims dilates{0, 0}; dilates[0] = param.dilate[0] - 1; @@ -75,14 +98,14 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, out_md, strides, dilates, padding, padding, mkldnn::padding_kind::zero); - return mkldnn::convolution_forward::primitive_desc(desc, engine); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); } else { auto bias_md = GetMemDesc(*bias); mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, bias_md, out_md, strides, dilates, padding, padding, mkldnn::padding_kind::zero); - return mkldnn::convolution_forward::primitive_desc(desc, engine); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); } } } @@ -207,16 +230,20 @@ void MKLDNNConvForward::SetNewMem(const mkldnn::memory &data, } } -MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs, const bool is_train, - const NDArray &data, const NDArray &weights, - const NDArray *bias, const NDArray &output) { +MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs &attrs, + const bool is_train, const NDArray &data, + const NDArray &weights, const NDArray *bias, + const NDArray &output) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map fwds; #else static MX_THREAD_LOCAL std::unordered_map fwds; #endif const ConvolutionParam& param = nnvm::get(attrs.parsed); + ConvFusionParam fusion_param; + fusion_param.Init(attrs.dict, dmlc::parameter::kAllowUnknown); MKLDNNConvSignature key(param); + 
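// Include the fusion configuration in the cache key so that convolutions + // with different post-op fusions do not share a cached forward primitive. +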
key.AddSign(fusion_param.GetHash()); key.AddSign(is_train); // Here we can sign the conv op with NDArray because conv primitive will // decide the right layout for the, so we only need to get the shape and the @@ -227,9 +254,10 @@ MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs, const bool is_train, if (bias) key.AddSign(*bias); + auto it = fwds.find(key); if (it == fwds.end()) { - MKLDNNConvForward fwd(param, is_train, data, weights, bias, output); + MKLDNNConvForward fwd(param, fusion_param, is_train, data, weights, bias, output); auto ins_ret = fwds.insert( std::pair(key, fwd)); CHECK(ins_ret.second); @@ -238,15 +266,20 @@ MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs, const bool is_train, return it->second; } -void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data) { +void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); - const ConvolutionParam& param = nnvm::get(attrs.parsed); + const ConvolutionParam ¶m = nnvm::get(attrs.parsed); + ConvFusionParam fusion_param; + fusion_param.Init(attrs.dict, dmlc::parameter::kAllowUnknown); NDArray weight = in_data[conv::kWeight]; - MKLDNNConvForward &fwd = GetConvFwd(attrs, ctx.is_train, in_data[conv::kData], weight, - param.no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]); + bool no_bias = param.no_bias && !fusion_param.with_bn; + MKLDNNConvForward &fwd = GetConvFwd( + attrs, ctx.is_train, in_data[conv::kData], weight, + no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]); auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc()); const mkldnn::memory *weight_mem; @@ -271,11 +304,21 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc()); } } - auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(), - req[conv::kOut]); + mkldnn_output_t out_mem; + if (fusion_param.with_sum) { + out_mem = mkldnn_output_t( + OutDataOp::Noop, + const_cast(out_data[conv::kOut].GetMKLDNNDataReorder( + fwd.fwd_pd.dst_primitive_desc()))); + } else { + out_mem = CreateMKLDNNMem(out_data[conv::kOut], + fwd.fwd_pd.dst_primitive_desc(), req[conv::kOut]); + } + const mkldnn::memory *bias_mem = nullptr; - if (!param.no_bias) - bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc()); + if (!no_bias) { + bias_mem = in_data[conv::kBias].GetMKLDNNData(); + } fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); @@ -290,8 +333,11 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); const std::vector &in_grad = outputs; const ConvolutionParam& param = nnvm::get(attrs.parsed); - mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl(param, ctx.is_train, - inputs[conv::kData + 1], inputs[conv::kWeight + 1], + ConvFusionParam fusion_param; + fusion_param.Init(attrs.dict, dmlc::parameter::kAllowUnknown); + mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl( + param, fusion_param, ctx.is_train, inputs[conv::kData + 1], + inputs[conv::kWeight + 1], param.no_bias ? 
nullptr : &inputs[conv::kBias + 1], inputs[conv::kOut]); CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace"; diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc new file mode 100644 index 000000000000..8dd1d8cb26f9 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -0,0 +1,325 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +#if MXNET_USE_MKLDNN == 1 +#include +#include +#include "./mkldnn_conv.h" +#include "../../nn/mkldnn/mkldnn_ops-inl.h" +#include "../../../imperative/imperative_utils.h" +#include "../../../imperative/cached_op.h" +#include "../../nn/convolution-inl.h" +#include "../../nn/batch_norm-inl.h" +namespace mxnet { +namespace op { + +#define SUBGRAPH_DEBUG 0 + +template +static void UpdateConvWeightBias(const NDArray &weight, const NDArray *bias, + const NDArray &gamma, const NDArray &beta, + const NDArray &mean, + const NDArray &variance, + std::shared_ptr update_weight, + std::shared_ptr update_bias, + const BatchNormParam ¶m) { +#if SUBGRAPH_DEBUG + printf("input weight: %f %f %f %f \n", weight.data().dptr()[0], + weight.data().dptr()[1], + weight.data().dptr()[2], + weight.data().dptr()[3]); + printf("bn param eps: %f \n", param.eps); + printf("bn param fix_gamma: %d \n", param.fix_gamma); + printf("bn param use_global_stats: %d \n", param.use_global_stats); + printf("bn param output_mean_var: %d \n", param.output_mean_var); + printf("bn param axis: %d \n", param.axis); +#endif + DType *weight_ptr = weight.Reorder2Default().data().dptr(); + DType *bias_ptr = bias ? bias->Reorder2Default().data().dptr() : nullptr; + DType *gamma_ptr = gamma.Reorder2Default().data().dptr(); + DType *beta_ptr = beta.Reorder2Default().data().dptr(); + DType *mean_ptr = mean.Reorder2Default().data().dptr(); + DType *var_ptr = variance.Reorder2Default().data().dptr(); + DType *update_weight_ptr = update_weight->data().dptr(); + DType *update_bias_ptr = update_bias->data().dptr(); + size_t channel = gamma.shape()[0]; + size_t offset = weight.shape()[1] * weight.shape()[2] * weight.shape()[3]; +#pragma omp parallel for + for (size_t c = 0; c < channel; ++c) { + DType *p1 = reinterpret_cast(weight_ptr + c * offset); + DType *p2 = reinterpret_cast(update_weight_ptr + c * offset); + DType alpha = (param.fix_gamma ? 
static_cast(1.0f) : gamma_ptr[c]) / + sqrt(var_ptr[c] + param.eps); + + if (bias_ptr) + update_bias_ptr[c] = beta_ptr[c] + alpha * (bias_ptr[c] - mean_ptr[c]); + else + update_bias_ptr[c] = beta_ptr[c] - alpha * mean_ptr[c]; + + for (size_t k = 0; k < offset; ++k) { + p2[k] = p1[k] * alpha; + } + } +#if SUBGRAPH_DEBUG + printf("update weight: %f %f %f %f \n", update_weight->data().dptr()[0], + update_weight->data().dptr()[1], + update_weight->data().dptr()[2], + update_weight->data().dptr()[3]); +#endif +} + +static void ConvFusionFallBackCompute() { + LOG(FATAL) << "Don't know how to do ConvFusionFallBackCompute!"; +} + +static void ConvolutionFusionComputeExCPU(const nnvm::NodeAttrs &conv_attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const ConvolutionParam ¶ms = + nnvm::get(conv_attrs.parsed); + if (SupportMKLDNNConv(params, inputs[0])) { + // MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); + MKLDNNConvolutionForward(conv_attrs, ctx, inputs, req, outputs); + // MKLDNN_OPCHECK_RUN(ConvolutionCompute, attrs, ctx, inputs, req, + // outputs); + return; + } + ConvFusionFallBackCompute(); +} + +class SgMKLDNNConvOperator { + public: + explicit SgMKLDNNConvOperator(const nnvm::NodeAttrs &attrs) + : subgraph_sym_(nnvm::get(attrs.parsed)), + // subgraph_exec_(nullptr), + cached_weight_(nullptr), + cached_bias_(nullptr), + bn_attrs_(nullptr), + conv_attrs_(nullptr), + in_sum_at_begin(false), + with_bn(false), + with_relu(false), + with_sum(false), + with_postsum_relu(false) { + // subgraph_exec_.reset(new CachedOp(subgraph_sym_, {{"static_alloc", "true"}})); + auto it = attrs.dict.find("in_sum_at_begin"); + if (it != attrs.dict.end()) + in_sum_at_begin = (it->second == "true"); + it = attrs.dict.find("with_bn"); + if (it != attrs.dict.end()) + with_bn = (it->second == "true"); + it = attrs.dict.find("with_relu"); + if (it != attrs.dict.end()) + with_relu = (it->second == "true"); + it = attrs.dict.find("with_sum"); + if (it != attrs.dict.end()) + with_sum = (it->second == "true"); + it = attrs.dict.find("with_postsum_relu"); + if (it != attrs.dict.end()) + with_postsum_relu = (it->second == "true"); + + DFSVisit(subgraph_sym_.outputs, [&](const nnvm::NodePtr &node) { + if (node->is_variable()) return; + auto &node_name = node->op()->name; + if (node_name == "BatchNorm") { + CHECK(bn_attrs_.get() == nullptr); + CHECK_EQ(with_bn, true); + bn_attrs_ = std::make_shared(node->attrs); + } else if (node_name == "Convolution") { + CHECK(conv_attrs_.get() == nullptr); + conv_attrs_ = std::make_shared(node->attrs); + } + }); + CHECK(conv_attrs_.get()); + conv_attrs_->dict["with_bn"] = with_bn ? "true" : "false"; + conv_attrs_->dict["with_relu"] = with_relu ? "true" : "false"; + conv_attrs_->dict["with_sum"] = with_sum ? "true" : "false"; + conv_attrs_->dict["with_postsum_relu"] = with_postsum_relu ? 
"true" : "false"; + } + + void Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + void Backward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + LOG(FATAL) << "Not implemented: subgraph mkldnn Conv only supports inference computation"; + } + + private: + nnvm::Symbol subgraph_sym_; + // CachedOpPtr subgraph_exec_; // Used for fallback compute + std::shared_ptr cached_weight_; + std::shared_ptr cached_bias_; + std::shared_ptr bn_attrs_; + std::shared_ptr conv_attrs_; + bool in_sum_at_begin; + bool with_bn; + bool with_relu; + bool with_sum; + bool with_postsum_relu; +}; + +void SgMKLDNNConvOperator::Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const ConvolutionParam &conv_params = nnvm::get(conv_attrs_->parsed); +#if SUBGRAPH_DEBUG + LOG(INFO) << "Conv inputs size: " << inputs.size(); + LOG(INFO) << "Conv outputs size: " << outputs.size(); + LOG(INFO) << "Conv req size: " << req.size(); + for (size_t k = 0; k < inputs.size(); ++k) { + auto input = inputs[k]; + printf("input %ld :", k); + for (size_t i = 0; i < input.shape().ndim(); ++i) { + printf("%ld ", input.shape()[i]); + } + printf("\n"); + } + CHECK_EQ(ctx.is_train, false); + printf("output:"); + for (size_t i = 0; i < outputs[0].shape().ndim(); ++i) { + printf("%ld ", outputs[0].shape()[i]); + } + printf("\n"); +#endif + size_t input_size = 2 + (conv_params.no_bias ? 0 : 1) + (with_bn ? 4 : 0) + + (with_sum ? 1 : 0); + CHECK_EQ(inputs.size(), input_size); + size_t idx = 0; + auto in_sum = in_sum_at_begin ? (idx++) : 0; + auto in_data = idx++; + auto in_weight = idx++; + auto in_bias = conv_params.no_bias ? 0 : (idx++); + auto in_gamma = with_bn ? (idx++) : 0; + auto in_beta = with_bn ? (idx++) : 0; + auto in_mean = with_bn ? (idx++) : 0; + auto in_var = with_bn ? (idx++) : 0; + in_sum = ((!in_sum_at_begin) && with_sum) ? (idx++) : 0; + auto output = outputs[0]; + CHECK_EQ(input_size, idx); + + if (with_bn && (nullptr == cached_weight_ || nullptr == cached_bias_)) { + CHECK_EQ(inputs[in_weight].dtype(), inputs[in_gamma].dtype()); + CHECK_EQ(inputs[in_weight].dtype(), inputs[in_beta].dtype()); + CHECK_EQ(inputs[in_weight].dtype(), inputs[in_var].dtype()); + const BatchNormParam &bn_param = + nnvm::get(bn_attrs_->parsed); + cached_weight_ = std::make_shared( + inputs[in_weight].storage_type(), inputs[in_weight].shape(), + inputs[in_weight].ctx(), true, inputs[in_weight].dtype()); + cached_bias_ = std::make_shared( + inputs[in_beta].storage_type(), inputs[in_beta].shape(), + inputs[in_beta].ctx(), true, inputs[in_beta].dtype()); + MSHADOW_REAL_TYPE_SWITCH(inputs[in_weight].dtype(), DType, { + UpdateConvWeightBias( + inputs[in_weight], conv_params.no_bias ? 
nullptr : &inputs[in_bias], + inputs[in_gamma], inputs[in_beta], inputs[in_mean], inputs[in_var], + cached_weight_, cached_bias_, bn_param); + }); + } + std::vector new_inputs; + std::vector new_req; + std::vector new_outputs; + if (with_bn) { + new_inputs = {inputs[in_data], *cached_weight_, *cached_bias_}; + new_req = {req[in_data], req[in_weight], req[in_beta]}; + } else { + if (conv_params.no_bias) { + new_inputs = {inputs[in_data], inputs[in_weight]}; + new_req = {req[in_data], req[in_weight]}; + } else { + new_inputs = {inputs[in_data], inputs[in_weight], inputs[in_bias]}; + new_req = {req[in_data], req[in_weight], req[in_bias]}; + } + } + if (with_sum) + new_outputs = {inputs[in_sum]}; + else + new_outputs = {output}; + ConvolutionFusionComputeExCPU(*conv_attrs_, ctx, new_inputs, new_req, + new_outputs); + } + + OpStatePtr CreateSgMKLDNNConvOpState(const nnvm::NodeAttrs &attrs, Context ctx, + const std::vector &in_shapes, + const std::vector &in_types) { + return OpStatePtr::Create(attrs); + } + + void SgMKLDNNConvOpForward(const OpStatePtr &state_ptr, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + SgMKLDNNConvOperator &op = state_ptr.get_state(); + op.Forward(ctx, inputs, req, outputs); + } + +NNVM_REGISTER_OP(_sg_mkldnn_conv) +.describe(R"code(_sg_mkldnn_conv)code" ADD_FILELINE) +.set_num_inputs(DefaultSubgraphOpNumInputs) +.set_num_outputs(DefaultSubgraphOpNumOutputs) +.set_attr("FListInputNames", + DefaultSubgraphOpListInputs) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) +.set_attr("FCreateOpState", CreateSgMKLDNNConvOpState) +.set_attr("FInferShape", DefaultSubgraphOpShape) +.set_attr("FInferType", DefaultSubgraphOpType) +.set_attr("FInferStorageType", + DefaultSubgraphOpStorageType) +.set_attr("FStatefulComputeEx", + SgMKLDNNConvOpForward) +.set_attr("FMutateInputs", + DefaultSubgraphOpMutableInputs) +.set_attr("FResourceRequest", + DefaultSubgraphOpResourceRequest) +.set_attr("key_var_num_args", "num_args") +.set_attr("FInplaceOption", [](const nnvm::NodeAttrs + &attrs) { + auto it = attrs.dict.find("with_sum"); + if (it != attrs.dict.end() && it->second == "true") { + it = attrs.dict.find("in_sum_at_begin"); + if (it != attrs.dict.end() && it->second == "true") { + return std::vector>{std::pair{0, 0}}; + } else { + it = attrs.dict.find("no_bias"); + CHECK(it != attrs.dict.end()); + bool no_bias = it->second == "true"; + it = attrs.dict.find("with_bn"); + bool with_bn = (it != attrs.dict.end()) ? it->second == "true" : false; + int idx = 2 + (no_bias ? 0 : 1) + (with_bn ? 4 : 0); + return std::vector>{std::pair{idx, 0}}; + } + } else { + return std::vector>(); + } +}); +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.h b/src/operator/subgraph/mkldnn/mkldnn_conv.h new file mode 100644 index 000000000000..42dbf6470307 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.h @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_CONV_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_CONV_H_ + +#if MXNET_USE_MKLDNN == 1 + +#include "../common.h" +#include "../subgraph_property.h" +#include "../../nn/convolution-inl.h" +#include "../../nn/activation-inl.h" + +namespace mxnet { +namespace op { +namespace sg { + +class SgMKLDNNConvSelector : public SubgraphSelector { + public: + /*! \brief pattern match status */ + enum SelectStatus { + sFail = 0, + sStart, + sBN, + sSum, + sSuccess, + }; + + private: + bool disable_conv_bn; + bool disable_conv_relu; + bool disable_conv_sum; + bool disable_all; + SelectStatus status; + nnvm::NodeEntry conv_data; + std::vector<const nnvm::Node *> matched_list; + + bool HandleMatchStatus() { + if (matched_list.size() > 1) { + status = sSuccess; + } else { + status = sFail; + } + return false; + } + + public: + SgMKLDNNConvSelector(int dis_conv_bn, int dis_conv_relu, int dis_conv_sum) + : disable_conv_bn(dis_conv_bn), + disable_conv_relu(dis_conv_relu), + disable_conv_sum(dis_conv_sum), + disable_all(disable_conv_bn && disable_conv_relu && disable_conv_sum) {} + + virtual bool Select(const nnvm::Node &n) override { + bool match = + (!disable_all) && (!n.is_variable()) && (n.op()->name == "Convolution"); + if (match) { + status = sStart; + conv_data = n.inputs[0]; + matched_list.clear(); + matched_list.push_back(&n); + return true; + } + return false; + } + + virtual bool SelectInput(const nnvm::Node &n, + const nnvm::Node &new_node) override { + return false; + } + + virtual bool SelectOutput(const nnvm::Node &n, + const nnvm::Node &new_node) override { + if (status == sFail || status == sSuccess || new_node.is_variable()) + return false; + // If n isn't the last matched node, then we encountered an internal + // branch, so we should pop out the nodes behind n and stop fusion. + if (matched_list.back() != &n) { + while (matched_list.back() != &n) { + matched_list.pop_back(); + } + // If more than one node remains, we can still do fusion. + return HandleMatchStatus(); + } + // Use a state machine to do the selection. The status transitions are + // sStart -> sBN -> sSum -> sSuccess. + switch (status) { + case sStart: + if ((!disable_conv_bn) && new_node.op()->name == "BatchNorm") { + matched_list.push_back(&new_node); + status = sBN; + return true; + } + case sBN: + if ((!disable_conv_sum) && new_node.op()->name == "elemwise_add") { + // Make sure n is the left operand of the sum; if not, + // swap the sum operands so that the extra sum operand + // stays last in the inputs. + auto sum_entry = new_node.inputs[1]; + if (new_node.inputs[1].node.get() == &n) { + sum_entry = new_node.inputs[0]; + } + #if 0 + if (sum_entry.node == conv_data.node && + sum_entry.index == conv_data.index) { + // In this situation we face a structure like + // data -> conv -> sum + // \---------/ + // Since conv+sum operates in place, the sum's output + // would override data, which is not supported. 
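+ // (Here the sum operand aliases the convolution's input, so an in-place + // sum would clobber the data tensor.)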
+ return HandleMatchStatus(); + } + #endif + matched_list.push_back(&new_node); + status = sSum; + return true; + } + case sSum: + default: + if ((!disable_conv_relu) && new_node.op()->name == "Activation") { + const ActivationParam ¶m = + nnvm::get(new_node.attrs.parsed); + if (param.act_type == activation::kReLU) { + matched_list.push_back(&new_node); + // If we find conv+relu, then we can't match bn anymore. + if (status == sStart) status = sBN; + return true; + } else { + return HandleMatchStatus(); + } + } + return HandleMatchStatus(); + } + } + + virtual std::vector Filter( + const std::vector &candidates) override { + if (status == sFail || candidates.size() <= 1) { + return std::vector(0); + } else { + return candidates; + } + } +}; + +class SgMKLDNNConvProperty : public SubgraphProperty { + public: + SgMKLDNNConvProperty() { + int disable_all = dmlc::GetEnv("MXNET_DISABLE_FUSION_ALL", 0); + disable_conv_bn = dmlc::GetEnv("MXNET_DISABLE_FUSION_CONV_BN", 0); + disable_conv_relu = dmlc::GetEnv("MXNET_DISABLE_FUSION_CONV_RELU", 0); + disable_conv_sum = dmlc::GetEnv("MXNET_DISABLE_FUSION_CONV_SUM", 0); + + if (disable_all || + (disable_conv_bn && disable_conv_relu && disable_conv_sum)) { + LOG(INFO) << "MKLDNN Convolution fusion pass is disabled. Fusion " + "configurations: "; + } else { + LOG(INFO) << "Start to execute MKLDNN Convolution fusion pass. Fusion " + "configurations:"; + } + LOG(INFO) << "MXNET_DISABLE_FUSION_ALL=" << disable_all; + LOG(INFO) << "MXNET_DISABLE_FUSION_CONV_BN=" << disable_conv_bn; + LOG(INFO) << "MXNET_DISABLE_FUSION_CONV_RELU=" << disable_conv_relu; + LOG(INFO) << "MXNET_DISABLE_FUSION_CONV_SUM=" << disable_conv_sum; + if (disable_all) { + disable_conv_bn = 1; + disable_conv_relu = 1; + disable_conv_sum = 1; + } + } + static SubgraphPropertyPtr Create() { + return std::make_shared(); + } + nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, + const int subgraph_id = 0) const override { + nnvm::NodePtr n = nnvm::Node::Create(); + // Initialize new attributes to false + n->attrs.dict["in_sum_at_begin"] = "false"; + n->attrs.dict["no_bias"] = "false"; + n->attrs.dict["with_bn"] = "false"; + n->attrs.dict["with_sum"] = "false"; + n->attrs.dict["with_relu"] = "false"; + n->attrs.dict["with_postsum_relu"] = "false"; + // This op has single output, remove duplicated. + auto last_node = sym.outputs[0].node; + nnvm::Symbol new_sym; + new_sym.outputs.emplace_back(nnvm::NodeEntry{last_node, 0, 0}); + std::string node_name = ""; + bool _with_sum = false; + std::unordered_set node_sets; + DFSVisit(new_sym.outputs, [&](const nnvm::NodePtr &node) { + if (node->is_variable()) return; + node_sets.insert(node.get()); + auto &sub_name = node->op()->name; + if (sub_name == "Convolution") { + node_name += "Conv_"; + const ConvolutionParam &conv_params = + nnvm::get(node->attrs.parsed); + n->attrs.dict["no_bias"] = conv_params.no_bias ? 
"true" : "false"; + } else if (sub_name == "BatchNorm") { + node_name += "BN_"; + n->attrs.dict["with_bn"] = "true"; + } else if (sub_name == "elemwise_add") { + node_name += "Add_"; + n->attrs.dict["with_sum"] = "true"; + _with_sum = true; + if (node_sets.count(node->inputs[1].node.get())) { + n->attrs.dict["in_sum_at_begin"] = "true"; + } else { + CHECK_NE(node_sets.count(node->inputs[0].node.get()), 0U); + } + } else if (sub_name == "Activation") { + node_name += "Relu_"; + if (!_with_sum) { + n->attrs.dict["with_relu"] = "true"; + } else { + n->attrs.dict["with_postsum_relu"] = "true"; + } + } + }); + + n->attrs.name = "sg_mkldnn_" + node_name + std::to_string(subgraph_id); + n->attrs.op = Op::Get("_sg_mkldnn_conv"); + CHECK(n->attrs.op); + n->attrs.parsed = new_sym; + return n; + } + + virtual SubgraphSelectorPtr CreateSubgraphSelector() const override { + auto selector = std::make_shared( + disable_conv_bn, disable_conv_relu, disable_conv_sum); + return selector; + } + + virtual void ConnectSubgraphOutput( + const nnvm::NodePtr n, + std::vector &output_entries) const override { + // Connect all extern output entries to output[0] + for (size_t i = 0; i < output_entries.size(); ++i) { + *output_entries[i] = nnvm::NodeEntry{n, 0, 0}; + } + } + + private: + int disable_conv_bn; + int disable_conv_relu; + int disable_conv_sum; +}; + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty); + +} // namespace sg +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_CONV_H_ diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc index e8c3069255c3..7c2f392b9288 100644 --- a/src/operator/subgraph/partition_graph.cc +++ b/src/operator/subgraph/partition_graph.cc @@ -630,11 +630,8 @@ void CreateSubgraphNode(Graph* g, } const SubgraphPropertyPtr& subg_prop = g->GetAttr("subgraph_property"); nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym, subgraph_id); - // Connect the external nodes to the subgraph node. - for (size_t i = 0; i < output_entries.size(); ++i) { - *output_entries[i] = nnvm::NodeEntry{n, static_cast(i), 0}; - } + subg_prop->ConnectSubgraphOutput(n, output_entries); n->inputs = orig_input_entries; const auto& indexed_graph = g->indexed_graph(); for (size_t i = 0; i < n->inputs.size(); ++i) { diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index 2153a366471a..9bf4d4aa351c 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -67,6 +67,14 @@ class SubgraphProperty { // execute the operators in the subgraph. virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s, const int subgraph_id = 0) const = 0; + // Connect subgraph internal output with external output entries. By default, + // each output entry will connect to an unique internal output. + virtual void ConnectSubgraphOutput(const nnvm::NodePtr n, + std::vector& output_entries) const { + for (size_t i = 0; i < output_entries.size(); ++i) { + *output_entries[i] = nnvm::NodeEntry{n, static_cast(i), 0}; + } + } // set an attr with name in the attr map template SubgraphProperty& SetAttr(const std::string& name, const T& value) {