From 786087a9b581e1a9c92eef6ba88355d9d699d0eb Mon Sep 17 00:00:00 2001 From: Bartosz Kuncer Date: Tue, 17 Aug 2021 10:52:25 +0200 Subject: [PATCH 1/4] [operator] Integrate oneDNN layer normalization implementation --- src/operator/nn/layer_norm-inl.h | 26 +- src/operator/nn/layer_norm.cc | 66 +++++ src/operator/nn/layer_norm.cu | 4 + src/operator/nn/mkldnn/mkldnn_base-inl.h | 2 + .../nn/mkldnn/mkldnn_layer_norm-inl.h | 102 +++++++ src/operator/nn/mkldnn/mkldnn_layer_norm.cc | 260 ++++++++++++++++++ src/operator/nn/mkldnn/mkldnn_ops-inl.h | 12 + 7 files changed, 471 insertions(+), 1 deletion(-) create mode 100644 src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h create mode 100644 src/operator/nn/mkldnn/mkldnn_layer_norm.cc diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h index 79d09063ee6c..b55ef292cb9d 100644 --- a/src/operator/nn/layer_norm-inl.h +++ b/src/operator/nn/layer_norm-inl.h @@ -45,7 +45,9 @@ namespace op { namespace layernorm { enum LayerNormOpInputs {kData, kGamma, kBeta}; // kGamma: scaling parameters, kBeta: shift biases -enum LayerNormOpOutputs {kOut, kMean, kStd}; // req, out_data +enum LayerNormOpOutputs {kOut, kMean, kStd}; // indices for req, out_data +enum LayerNormOpInputsBwd {kBwdOutGrad, kBwdData, kBwdGamma, kBwdMean, kBwdStd, kBwdBeta}; +enum LayerNormOpOutputsBwd {kBwdDataGrad, kBwdGammaGrad, kBwdBetaGrad}; } // namespace layernorm struct LayerNormParam : public dmlc::Parameter { @@ -71,6 +73,11 @@ struct LayerNormParam : public dmlc::Parameter { (*dict)["eps"] = eps_s.str(); (*dict)["output_mean_var"] = output_mean_var_s.str(); } + + bool operator==(const LayerNormParam& other) const { + return (this->axis == other.axis && this->eps == other.eps && + this->output_mean_var == other.output_mean_var); + } }; inline int GetRealAxis(int axis, int ndim) { @@ -257,7 +264,11 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; +#if MXNET_USE_ONEDNN == 1 + CHECK_EQ(inputs.size(), 6U); // additional beta tensor +#else CHECK_EQ(inputs.size(), 5U); +#endif const LayerNormParam& param = nnvm::get(attrs.parsed); int axis = param.axis; if (axis < 0) { @@ -313,4 +324,17 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, } // namespace op } // namespace mxnet + +namespace std { +template <> +struct hash { + size_t operator()(const mxnet::op::LayerNormParam& val) { + size_t ret = 0; + ret = dmlc::HashCombine(ret, val.axis); + ret = dmlc::HashCombine(ret, val.eps); + ret = dmlc::HashCombine(ret, val.output_mean_var); + return ret; + } +}; +} // namespace std #endif // MXNET_OPERATOR_NN_LAYER_NORM_INL_H_ diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index 1a040fa6f7d0..4e8a80e74ad8 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -57,6 +57,10 @@ #include "layer_norm-inl.h" #include #include "../elemwise_op_common.h" +#if MXNET_USE_ONEDNN == 1 +#include "./mkldnn/mkldnn_base-inl.h" +#include "./mkldnn/mkldnn_ops-inl.h" +#endif // MXNET_USE_ONEDNN #if MSHADOW_USE_MKL == 1 #include "../mkl_functions-inl.h" @@ -392,6 +396,50 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, return LayerNormGradComputeGeneral(attrs, ctx, inputs, req, outputs); } +#if MXNET_USE_ONEDNN == 1 +static bool LayerNormInferStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK(!in_attrs->empty()); + + 
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+
+static void LayerNormComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                  const OpContext& ctx,
+                                  const std::vector<NDArray>& inputs,
+                                  const std::vector<OpReqType>& req,
+                                  const std::vector<NDArray>& outputs) {
+  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  if (SupportMKLDNNLayerNorm(param, inputs)) {
+    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    MKLDNNRun(MKLDNNLayerNormForward, attrs, ctx, inputs, req, outputs);
+    MKLDNN_OPCHECK_RUN(LayerNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  } else {
+    FallBackCompute(LayerNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
+  }
+}
+
+static void LayerNormGradComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                      const OpContext& ctx,
+                                      const std::vector<NDArray>& inputs,
+                                      const std::vector<OpReqType>& req,
+                                      const std::vector<NDArray>& outputs) {
+  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  if (SupportMKLDNNLayerNorm(param, inputs)) {
+    MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
+    MKLDNNRun(MKLDNNLayerNormBackward, attrs, ctx, inputs, req, outputs);
+    MKLDNN_OPCHECK_RUN(LayerNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  } else {
+    FallBackCompute(LayerNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+  }
+}
+#endif
+
 NNVM_REGISTER_OP(LayerNorm)
 .add_alias("_npx_layer_norm")
 .describe(R"code(Layer normalization.
@@ -439,6 +487,11 @@ axis to be the last item in the input shape.
 .set_attr<mxnet::FInferShape>("FInferShape", LayerNormShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 3>)
 .set_attr<FCompute>("FCompute<cpu>", LayerNormCompute<cpu>)
+#if MXNET_USE_ONEDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FInferStorageType>("FInferStorageType", LayerNormInferStorageType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", LayerNormComputeExCPU)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", [](const nnvm::ObjectPtr& n,
                                            const std::vector<nnvm::NodeEntry>& ograds) {
   std::vector<nnvm::NodeEntry> heads;
@@ -447,6 +500,10 @@ axis to be the last item in the input shape.
   heads.push_back(n->inputs[1]);  // gamma
   heads.emplace_back(n, 1, 0);  // mean
   heads.emplace_back(n, 2, 0);  // std
+#if MXNET_USE_ONEDNN == 1
+  heads.push_back(n->inputs[2]);  // beta - needed for MKLDNN backward propagation;
+                                  // added at the end in case of fallback to non MKLDNN version
+#endif
   return MakeGradNode("_backward_LayerNorm", n, heads, n->attrs.dict);
 })
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
NNVM_REGISTER_OP(_backward_LayerNorm) +#if MXNET_USE_ONEDNN == 1 +.set_num_inputs(6) +#else .set_num_inputs(5) +#endif .set_num_outputs(3) .set_attr("TIsBackward", true) .set_attr_parser(ParamParser) .set_attr("FCompute", LayerNormGradCompute) +#if MXNET_USE_ONEDNN == 1 +.set_attr("FInferStorageType", LayerNormInferStorageType) +.set_attr("TIsMKLDNN", true) +.set_attr("FComputeEx", LayerNormGradComputeExCPU) +#endif .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }); diff --git a/src/operator/nn/layer_norm.cu b/src/operator/nn/layer_norm.cu index 9a33e0665ff4..1e719eda40ab 100644 --- a/src/operator/nn/layer_norm.cu +++ b/src/operator/nn/layer_norm.cu @@ -689,7 +689,11 @@ void LayerNormGradGPUContig(const LayerNormParam param, const std::vector& req, const std::vector& outputs) { using namespace mshadow; +#if MXNET_USE_ONEDNN == 1 + CHECK_EQ(inputs.size(), 6U); // additional beta tensor +#else CHECK_EQ(inputs.size(), 5U); +#endif const TBlob out_grad = inputs[0]; const TBlob in_data = inputs[1]; const TBlob gamma = inputs[2]; diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 2cef524aa4b7..2ee0793d3db2 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -202,6 +202,7 @@ struct SoftmaxParam; struct SoftmaxOutputParam; struct TransposeParam; struct ReshapeParam; +struct LayerNormParam; bool SupportMKLDNNAct(const ActivationParam& param); bool SupportMKLDNNAct(const ActivationParam& param, const NDArray& input); bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param); @@ -216,6 +217,7 @@ bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param, bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam& param); bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray& data); bool SupportMKLDNNBatchDot(const std::vector& inputs, const NDArray& output); +bool SupportMKLDNNLayerNorm(const LayerNormParam& param, const std::vector &inputs); } // namespace op static int GetTypeSize(int dtype) { diff --git a/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h new file mode 100644 index 000000000000..e8938acfb1d8 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file mkldnn_layer_norm-inl.h + */ +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_ + +#if MXNET_USE_ONEDNN == 1 + +#include +#include + +#include "../layer_norm-inl.h" +#include "./mkldnn_base-inl.h" +#include "./mkldnn_ops-inl.h" + +namespace mxnet { +namespace op { + +using layernorm_fwd_t = mkldnn::layer_normalization_forward; +using layernorm_fwd_pd_t = mkldnn::layer_normalization_forward::primitive_desc; + +using layernorm_bwd_t = mkldnn::layer_normalization_backward; +using layernorm_bwd_pd_t = mkldnn::layer_normalization_backward::primitive_desc; + +typedef ParamOpSign LayerNormSignature; + +class MKLDNNLayerNormFwd { + public: + static MKLDNNLayerNormFwd& GetCached(const LayerNormParam& param, + const OpContext& ctx, + const NDArray& data); + + MKLDNNLayerNormFwd(const LayerNormParam& param, const NDArray& data); + + static std::shared_ptr CreatePrimitiveDesc( + const LayerNormParam& param, + const mkldnn::memory::desc& src_md); + + void Execute(const LayerNormParam& param, + const OpContext& ctx, + const std::vector& inputs, + const OpReqType& req, + const std::vector& outputs) const; + + ~MKLDNNLayerNormFwd() {} + + private: + std::shared_ptr fwd; + std::shared_ptr fwd_pd; +}; + +class MKLDNNLayerNormBwd { + public: + static MKLDNNLayerNormBwd& GetCached(const LayerNormParam& param, + const std::vector& inputs); + + MKLDNNLayerNormBwd(const LayerNormParam& param, + const std::vector& inputs, + const mkldnn::memory::desc& data_md, + const mkldnn::memory::desc& diff_md); + + static std::shared_ptr CreatePrimitiveDesc( + const LayerNormParam& param, + const mkldnn::memory::desc& data_md, + const mkldnn::memory::desc& diff_md, + const layernorm_fwd_pd_t& layernorm_fwd_pd); + + void Execute(const std::vector& inputs, + const std::vector& outputs, + const std::vector& req) const; + + ~MKLDNNLayerNormBwd() {} + + private: + std::shared_ptr bwd; + std::shared_ptr fwd_pd; + std::shared_ptr bwd_pd; +}; + +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_ONEDNN == 1 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H__ diff --git a/src/operator/nn/mkldnn/mkldnn_layer_norm.cc b/src/operator/nn/mkldnn/mkldnn_layer_norm.cc new file mode 100644 index 000000000000..e9133e5f3d8a --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_layer_norm.cc @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file mkldnn_layer_norm.cc
+ */
+
+#if MXNET_USE_ONEDNN == 1
+
+#include "./mkldnn_layer_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+
+bool SupportMKLDNNLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs) {
+  const mxnet::TShape& shape = inputs[layernorm::kData].shape();
+
+  // The native implementation (function LayerNormCPU) is faster than the oneDNN one for small
+  // tensors. The heuristic below, based on measurements on a CLX machine, decides whether a
+  // given shape is better suited to the oneDNN or the native implementation.
+  auto ShapeBetterForMKLDNN = [](const mxnet::TShape& shape) {
+    constexpr size_t shapeLimit = 1024;
+    return shape.Size() / shape[0] >= shapeLimit && shape[0] >= shapeLimit;
+  };
+
+  return (ShapeBetterForMKLDNN(shape) &&
+          (GetRealAxis(param.axis, shape.ndim()) == shape.ndim() - 1) && (shape.ndim() >= 2) &&
+          (shape.ndim() <= 5) &&
+          (inputs[layernorm::kData].dtype() == mshadow::kFloat32 ||
+           inputs[layernorm::kData].dtype() == mshadow::kBfloat16) &&
+          inputs[layernorm::kGamma].dtype() == mshadow::kFloat32 &&
+          inputs[layernorm::kBeta].dtype() == mshadow::kFloat32);
+}
+
+void MKLDNNLayerNormForward(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs) {
+  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  const auto& fwd = MKLDNNLayerNormFwd::GetCached(param, ctx, inputs[layernorm::kData]);
+  fwd.Execute(param, ctx, inputs, req[layernorm::kOut], outputs);
+}
+
+MKLDNNLayerNormFwd& MKLDNNLayerNormFwd::GetCached(const LayerNormParam& param,
+                                                  const OpContext& ctx,
+                                                  const NDArray& data) {
+  using layernorm_fwd_map = std::unordered_map<LayerNormSignature, MKLDNNLayerNormFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local layernorm_fwd_map layer_norm_fwds;
+#else
+  static MX_THREAD_LOCAL layernorm_fwd_map layer_norm_fwds;
+#endif
+
+  LayerNormSignature key(param);
+  key.AddSign(data);
+  key.AddSign(param.eps);
+
+  auto it = layer_norm_fwds.find(key);
+  if (it == layer_norm_fwds.end()) {
+    MKLDNNLayerNormFwd fwd(param, data);
+    it = AddToCache(&layer_norm_fwds, key, fwd);
+  }
+  return it->second;
+}
+
+MKLDNNLayerNormFwd::MKLDNNLayerNormFwd(const LayerNormParam& param, const NDArray& data) {
+  const mkldnn::memory::desc data_md = data.GetMKLDNNData()->get_desc();
+  fwd_pd = CreatePrimitiveDesc(param, data_md);
+  fwd = std::make_shared<mkldnn::layer_normalization_forward>(*fwd_pd);
+}
+
+std::shared_ptr<layernorm_fwd_pd_t> MKLDNNLayerNormFwd::CreatePrimitiveDesc(
+    const LayerNormParam& param,
+    const mkldnn::memory::desc& src_md) {
+  layernorm_fwd_t::desc fwd_desc(mkldnn::prop_kind::forward_training,
+                                 src_md,
+                                 param.eps,
+                                 dnnl::normalization_flags::use_scale_shift);
+  mkldnn::engine& engine = CpuEngine::Get()->get_engine();
+  return std::make_shared<layernorm_fwd_pd_t>(fwd_desc, engine);
+}
+
+inline mkldnn::memory::desc GetMeanVarDesc(const mkldnn::memory::data_type& dtype,
+                                           const mxnet::TShape& _shape) {
+  const auto ndim = _shape.ndim();
+
+  mkldnn::memory::dims shape(ndim, 1), strides(ndim, 1);
+  shape[0] = _shape[0];
+  for (int i = ndim - 1; i > 0; --i) {
+    shape[i] = _shape[i];
+    strides[i - 1] = strides[i] * shape[i];
+  }
+
+  return mkldnn::memory::desc{shape, dtype, strides};
+}
+
+inline mkldnn::memory GetScaleShiftMem(const NDArray& gamma, const NDArray& beta) {
+  // OneDNN takes gamma and beta as one SCALE_SHIFT tensor when both scale and shift are used.
+  // In MXNet, scale is called gamma and shift is called beta.
+ constexpr size_t gammaAndBeta = 2; + CHECK_EQ(gamma.shape()[0], beta.shape()[0]); + const mkldnn::memory::desc scale_shift_md(mkldnn::memory::dims{gammaAndBeta, gamma.shape()[0]}, + get_mkldnn_type(gamma.dtype()), + mkldnn::memory::format_tag::nc); + auto scale_shift_mem = mkldnn::memory(scale_shift_md, CpuEngine::Get()->get_engine()); + char* ptr = reinterpret_cast(scale_shift_mem.get_data_handle()); + const size_t bytes = scale_shift_md.get_size() / gammaAndBeta; + memcpy(ptr, gamma.data().dptr_, bytes); + memcpy(ptr + bytes, beta.data().dptr_, bytes); + return scale_shift_mem; +} + +void MKLDNNLayerNormFwd::Execute(const LayerNormParam& param, + const OpContext& ctx, + const std::vector& inputs, + const OpReqType& req, + const std::vector& outputs) const { + auto mean_var_md = GetMeanVarDesc(get_mkldnn_type(outputs[layernorm::kMean].dtype()), + outputs[layernorm::kMean].shape()); + auto mean_mem = mkldnn_output_t( + OutDataOp::Noop, + const_cast(outputs[layernorm::kMean]).CreateMKLDNNData(mean_var_md)); + auto variance_mem = + mkldnn_output_t(OutDataOp::Noop, + const_cast(outputs[layernorm::kStd]).CreateMKLDNNData(mean_var_md)); + + auto output_mem = CreateMKLDNNMem(outputs[layernorm::kOut], fwd_pd->dst_desc(), req); + auto scale_shift_mem = GetScaleShiftMem(inputs[layernorm::kGamma], inputs[layernorm::kBeta]); + + mkldnn_args_map_t args = {{MKLDNN_ARG_SRC, *inputs[layernorm::kData].GetMKLDNNData()}, + {MKLDNN_ARG_DST, *output_mem.second}, + {MKLDNN_ARG_MEAN, *mean_mem.second}, + {MKLDNN_ARG_VARIANCE, *variance_mem.second}, + {MKLDNN_ARG_SCALE_SHIFT, scale_shift_mem}}; + + MKLDNNStream::Get()->RegisterPrimArgs(*fwd, args); + CommitOutput(outputs[layernorm::kOut], output_mem); + CommitOutput(outputs[layernorm::kMean], mean_mem); + CommitOutput(outputs[layernorm::kStd], variance_mem); + MKLDNNStream::Get()->Submit(); +} + +MKLDNNLayerNormBwd::MKLDNNLayerNormBwd(const LayerNormParam& param, + const std::vector& inputs, + const mkldnn::memory::desc& data_md, + const mkldnn::memory::desc& diff_md) + : fwd_pd(MKLDNNLayerNormFwd::CreatePrimitiveDesc(param, data_md)), + bwd_pd(CreatePrimitiveDesc(param, data_md, diff_md, *fwd_pd)) { + bwd = std::make_shared(*bwd_pd); +} + +std::shared_ptr MKLDNNLayerNormBwd::CreatePrimitiveDesc( + const LayerNormParam& param, + const mkldnn::memory::desc& data_md, + const mkldnn::memory::desc& diff_md, + const layernorm_fwd_pd_t& layernorm_fwd_pd) { + layernorm_bwd_t::desc layernorm_bwd_desc(dnnl::prop_kind::backward, + diff_md, + data_md, + param.eps, + dnnl::normalization_flags::use_scale_shift); + mkldnn::engine& engine = CpuEngine::Get()->get_engine(); + return std::make_shared(layernorm_bwd_desc, engine, layernorm_fwd_pd); +} + +void MKLDNNLayerNormBwd::Execute(const std::vector& inputs, + const std::vector& outputs, + const std::vector& req) const { + auto scale_shift_mem = + GetScaleShiftMem(inputs[layernorm::kBwdGamma], inputs[layernorm::kBwdBeta]); + auto diff_weights_ndarray = NDArray(scale_shift_mem.get_desc()); + const auto bytes = inputs[layernorm::kBwdGamma].shape()[0] * sizeof(float); + const auto diff_weights_ndaray_data_ptr_plus_bytes = reinterpret_cast( + reinterpret_cast(diff_weights_ndarray.data().dptr_) + bytes); + if (req[layernorm::kBwdGammaGrad] == kAddTo) { + memcpy( + diff_weights_ndarray.data().dptr_, outputs[layernorm::kBwdGammaGrad].data().dptr_, bytes); + memcpy(diff_weights_ndaray_data_ptr_plus_bytes, + outputs[layernorm::kBwdBetaGrad].data().dptr_, + bytes); + } + mkldnn_output_t diff_src_mem = CreateMKLDNNMem( + 
outputs[layernorm::kBwdDataGrad], bwd_pd->diff_src_desc(), req[layernorm::kBwdDataGrad]); + mkldnn_output_t diff_weights_mem = CreateMKLDNNMem( + diff_weights_ndarray, bwd_pd->diff_weights_desc(), req[layernorm::kBwdGammaGrad]); + mkldnn_args_map_t args = {{MKLDNN_ARG_DIFF_DST, *inputs[layernorm::kBwdOutGrad].GetMKLDNNData()}, + {MKLDNN_ARG_SRC, *inputs[layernorm::kBwdData].GetMKLDNNData()}, + {MKLDNN_ARG_SCALE_SHIFT, scale_shift_mem}, + {MKLDNN_ARG_MEAN, *inputs[layernorm::kBwdMean].GetMKLDNNData()}, + {MKLDNN_ARG_VARIANCE, *inputs[layernorm::kBwdStd].GetMKLDNNData()}, + {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second}, + {MKLDNN_ARG_DIFF_SCALE_SHIFT, *diff_weights_mem.second}}; + MKLDNNStream::Get()->RegisterPrimArgs(*bwd, args); + CommitOutput(outputs[layernorm::kBwdDataGrad], diff_src_mem); + CommitOutput(diff_weights_ndarray, diff_weights_mem); + MKLDNNStream::Get()->Submit(); + // Commit scale_shift diff + memcpy(outputs[layernorm::kBwdGammaGrad].data().dptr_, diff_weights_ndarray.data().dptr_, bytes); + memcpy(outputs[layernorm::kBwdBetaGrad].data().dptr_, + diff_weights_ndaray_data_ptr_plus_bytes, + bytes); +} + +MKLDNNLayerNormBwd& MKLDNNLayerNormBwd::GetCached(const LayerNormParam& param, + const std::vector& inputs) { + using layernorm_bwd_map = std::unordered_map; +#if DMLC_CXX11_THREAD_LOCAL + static thread_local layernorm_bwd_map layer_norm_bwds; +#else + static MX_THREAD_LOCAL layernorm_bwd_map layer_norm_bwds; +#endif + LayerNormSignature key(param); + key.AddSign(inputs[layernorm::kBwdOutGrad]); + key.AddSign(inputs[layernorm::kBwdData]); + key.AddSign(inputs[layernorm::kBwdGamma]); + key.AddSign(inputs[layernorm::kBwdMean]); + key.AddSign(inputs[layernorm::kBwdStd]); + key.AddSign(inputs[layernorm::kBwdBeta]); + key.AddSign(param.eps); + + auto it = layer_norm_bwds.find(key); + if (it == layer_norm_bwds.end()) { + const mkldnn::memory::desc data_md = inputs[layernorm::kBwdData].GetMKLDNNData()->get_desc(); + const mkldnn::memory::desc diff_md = inputs[layernorm::kBwdOutGrad].GetMKLDNNData()->get_desc(); + MKLDNNLayerNormBwd bwd(param, inputs, data_md, diff_md); + it = AddToCache(&layer_norm_bwds, key, bwd); + } + return it->second; +} + +void MKLDNNLayerNormBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const LayerNormParam& param = nnvm::get(attrs.parsed); + MKLDNNLayerNormBwd& bwd = MKLDNNLayerNormBwd::GetCached(param, inputs); + bwd.Execute(inputs, outputs, req); +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_ONEDNN == 1 diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h index 3b8c39fd25e8..44a6b8fb3dd2 100644 --- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -165,6 +165,18 @@ void MKLDNNBatchDotForward(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs); +/* For layer normalization */ +void MKLDNNLayerNormForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector& req, + const std::vector &outputs); +void MKLDNNLayerNormBackward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + void MKLDNNSum(const mkldnn::memory& arr1, const mkldnn::memory& arr2, const mkldnn::memory& out); void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs, From 38d45fc718371e1e824909c6e230407084107c8a Mon Sep 17 
00:00:00 2001 From: bkuncer Date: Fri, 20 Aug 2021 19:28:35 +0800 Subject: [PATCH 2/4] change sizeof(float) to mshadow_sizeof(inputs[layernorm::kBwdGamma].dtype()) --- src/operator/nn/mkldnn/mkldnn_layer_norm.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_layer_norm.cc b/src/operator/nn/mkldnn/mkldnn_layer_norm.cc index e9133e5f3d8a..9484ade7d17e 100644 --- a/src/operator/nn/mkldnn/mkldnn_layer_norm.cc +++ b/src/operator/nn/mkldnn/mkldnn_layer_norm.cc @@ -186,7 +186,8 @@ void MKLDNNLayerNormBwd::Execute(const std::vector& inputs, auto scale_shift_mem = GetScaleShiftMem(inputs[layernorm::kBwdGamma], inputs[layernorm::kBwdBeta]); auto diff_weights_ndarray = NDArray(scale_shift_mem.get_desc()); - const auto bytes = inputs[layernorm::kBwdGamma].shape()[0] * sizeof(float); + const auto bytes = inputs[layernorm::kBwdGamma].shape()[0] * + mshadow::mshadow_sizeof(inputs[layernorm::kBwdGamma].dtype()); const auto diff_weights_ndaray_data_ptr_plus_bytes = reinterpret_cast( reinterpret_cast(diff_weights_ndarray.data().dptr_) + bytes); if (req[layernorm::kBwdGammaGrad] == kAddTo) { From 15825b588158dde27373a48dd35dea498bb78147 Mon Sep 17 00:00:00 2001 From: bkuncer Date: Fri, 20 Aug 2021 20:47:24 +0800 Subject: [PATCH 3/4] remove eps from key and unify layernorm_fwd_t/mkldnn::layer_normalization_forward --- src/operator/nn/mkldnn/mkldnn_layer_norm.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_layer_norm.cc b/src/operator/nn/mkldnn/mkldnn_layer_norm.cc index 9484ade7d17e..6b3f4f72cbec 100644 --- a/src/operator/nn/mkldnn/mkldnn_layer_norm.cc +++ b/src/operator/nn/mkldnn/mkldnn_layer_norm.cc @@ -70,7 +70,6 @@ MKLDNNLayerNormFwd& MKLDNNLayerNormFwd::GetCached(const LayerNormParam& param, LayerNormSignature key(param); key.AddSign(data); - key.AddSign(param.eps); auto it = layer_norm_fwds.find(key); if (it == layer_norm_fwds.end()) { @@ -83,7 +82,7 @@ MKLDNNLayerNormFwd& MKLDNNLayerNormFwd::GetCached(const LayerNormParam& param, MKLDNNLayerNormFwd::MKLDNNLayerNormFwd(const LayerNormParam& param, const NDArray& data) { const mkldnn::memory::desc data_md = data.GetMKLDNNData()->get_desc(); fwd_pd = CreatePrimitiveDesc(param, data_md); - fwd = std::make_shared(*fwd_pd); + fwd = std::make_shared(*fwd_pd); } std::shared_ptr MKLDNNLayerNormFwd::CreatePrimitiveDesc( @@ -234,7 +233,6 @@ MKLDNNLayerNormBwd& MKLDNNLayerNormBwd::GetCached(const LayerNormParam& param, key.AddSign(inputs[layernorm::kBwdMean]); key.AddSign(inputs[layernorm::kBwdStd]); key.AddSign(inputs[layernorm::kBwdBeta]); - key.AddSign(param.eps); auto it = layer_norm_bwds.find(key); if (it == layer_norm_bwds.end()) { From 240cac8fdf61f656a52aa06facf2d3677b95b260 Mon Sep 17 00:00:00 2001 From: bkuncer Date: Fri, 20 Aug 2021 21:13:15 +0800 Subject: [PATCH 4/4] add author --- src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h | 1 + src/operator/nn/mkldnn/mkldnn_batch_dot.cc | 1 + src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h | 1 + src/operator/nn/mkldnn/mkldnn_layer_norm.cc | 1 + 4 files changed, 4 insertions(+) diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h index 34c3eb9ec8f4..2459ea1a91e4 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h @@ -19,6 +19,7 @@ /*! 
* \file mkldnn_batch_dot-inl.h + * \author: Bartosz Kuncer, bartosz.kuncer@intel.com */ #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot.cc b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc index f7c93ef575ff..87ddb9876023 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_dot.cc +++ b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc @@ -19,6 +19,7 @@ /*! * \file mkldnn_batch_dot.cc + * \author: Bartosz Kuncer, bartosz.kuncer@intel.com */ #if MXNET_USE_ONEDNN == 1 diff --git a/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h index e8938acfb1d8..a14673b140db 100644 --- a/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h @@ -19,6 +19,7 @@ /*! * \file mkldnn_layer_norm-inl.h + * \author: Bartosz Kuncer, bartosz.kuncer@intel.com */ #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_ #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_layer_norm.cc b/src/operator/nn/mkldnn/mkldnn_layer_norm.cc index 6b3f4f72cbec..8b8e122c18f2 100644 --- a/src/operator/nn/mkldnn/mkldnn_layer_norm.cc +++ b/src/operator/nn/mkldnn/mkldnn_layer_norm.cc @@ -19,6 +19,7 @@ /*! * \file mkldnn_layer_norm.cc + * \author: Bartosz Kuncer, bartosz.kuncer@intel.com */ #if MXNET_USE_ONEDNN == 1
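
For reference (not part of the patches above): a minimal standalone sketch of the oneDNN forward layer-normalization call sequence that MKLDNNLayerNormFwd and GetScaleShiftMem wrap, assuming the oneDNN 2.x C++ API this series builds against. The sizes N and C, the epsilon value, and all buffer contents are hypothetical placeholders.

// Hedged sketch: oneDNN 2.x layer normalization, forward training, fused scale/shift.
#include <algorithm>
#include <vector>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  const memory::dim N = 2, C = 1024;  // hypothetical sizes; one mean/variance per row
  const float eps = 1e-5f;

  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  // Plain row-major f32 source; statistics are 1D with one value per normalized row.
  memory::desc src_md({N, C}, memory::data_type::f32, memory::format_tag::nc);
  memory::desc stat_md({N}, memory::data_type::f32, memory::format_tag::a);

  // Forward-training primitive with fused scale (gamma) and shift (beta).
  layer_normalization_forward::desc fwd_d(prop_kind::forward_training, src_md, stat_md,
                                          eps, normalization_flags::use_scale_shift);
  layer_normalization_forward::primitive_desc fwd_pd(fwd_d, eng);

  // gamma and beta are packed into a single {2, C} SCALE_SHIFT tensor:
  // row 0 holds gamma, row 1 holds beta (what GetScaleShiftMem's memcpys do).
  memory::desc ss_md({2, C}, memory::data_type::f32, memory::format_tag::nc);
  std::vector<float> gamma(C, 1.f), beta(C, 0.f), scale_shift(2 * C);
  std::copy(gamma.begin(), gamma.end(), scale_shift.begin());
  std::copy(beta.begin(), beta.end(), scale_shift.begin() + C);

  std::vector<float> src(N * C, 1.f), dst(N * C), mean(N), var(N);
  memory src_mem(src_md, eng, src.data());
  memory dst_mem(src_md, eng, dst.data());
  memory mean_mem(stat_md, eng, mean.data());
  memory var_mem(stat_md, eng, var.data());
  memory ss_mem(ss_md, eng, scale_shift.data());

  // Mean and variance are extra outputs in forward_training, as in MKLDNNLayerNormFwd::Execute.
  layer_normalization_forward(fwd_pd).execute(strm, {{DNNL_ARG_SRC, src_mem},
                                                     {DNNL_ARG_DST, dst_mem},
                                                     {DNNL_ARG_MEAN, mean_mem},
                                                     {DNNL_ARG_VARIANCE, var_mem},
                                                     {DNNL_ARG_SCALE_SHIFT, ss_mem}});
  strm.wait();
  return 0;
}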