This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
26 changes: 25 additions & 1 deletion src/operator/nn/layer_norm-inl.h
@@ -45,7 +45,9 @@ namespace op {

namespace layernorm {
enum LayerNormOpInputs {kData, kGamma, kBeta}; // kGamma: scaling parameters, kBeta: shift biases
enum LayerNormOpOutputs {kOut, kMean, kStd}; // req, out_data
enum LayerNormOpOutputs {kOut, kMean, kStd}; // indices for req, out_data
enum LayerNormOpInputsBwd {kBwdOutGrad, kBwdData, kBwdGamma, kBwdMean, kBwdStd, kBwdBeta};
enum LayerNormOpOutputsBwd {kBwdDataGrad, kBwdGammaGrad, kBwdBetaGrad};
} // namespace layernorm

struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
@@ -71,6 +73,11 @@ struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
(*dict)["eps"] = eps_s.str();
(*dict)["output_mean_var"] = output_mean_var_s.str();
}

bool operator==(const LayerNormParam& other) const {
return (this->axis == other.axis && this->eps == other.eps &&
this->output_mean_var == other.output_mean_var);
}
};

inline int GetRealAxis(int axis, int ndim) {
@@ -257,7 +264,11 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
#if MXNET_USE_ONEDNN == 1
CHECK_EQ(inputs.size(), 6U); // additional beta tensor
#else
CHECK_EQ(inputs.size(), 5U);
#endif
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
int axis = param.axis;
if (axis < 0) {
@@ -313,4 +324,17 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,

} // namespace op
} // namespace mxnet

namespace std {
template <>
struct hash<mxnet::op::LayerNormParam> {
size_t operator()(const mxnet::op::LayerNormParam& val) {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.axis);
ret = dmlc::HashCombine(ret, val.eps);
ret = dmlc::HashCombine(ret, val.output_mean_var);
return ret;
}
};
} // namespace std
#endif // MXNET_OPERATOR_NN_LAYER_NORM_INL_H_
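
The operator== and std::hash additions above exist so that LayerNormParam can key a cache of oneDNN primitives: ParamOpSign<LayerNormParam> (aliased as LayerNormSignature in mkldnn_layer_norm-inl.h later in this PR) hashes the parameter through std::hash<LayerNormParam> and falls back to operator== on hash collisions. The following is only an illustrative sketch of that caching pattern; GetCachedFwdSketch, the map name, and the exact AddSign calls are placeholders, not the PR's actual GetCached implementation.

#include <unordered_map>

// Illustrative sketch only - not part of this diff. Shows how the new
// std::hash<LayerNormParam> and LayerNormParam::operator== are consumed by
// the usual MXNet oneDNN primitive cache.
static MKLDNNLayerNormFwd& GetCachedFwdSketch(const LayerNormParam& param,
                                              const NDArray& data) {
  using LayerNormSignature = ParamOpSign<LayerNormParam>;
  // OpHash calls OpSignature::GetHash(); the param itself is hashed once in
  // the ParamOpSign constructor via std::hash<LayerNormParam>.
  static thread_local std::unordered_map<LayerNormSignature, MKLDNNLayerNormFwd, OpHash> cache;
  LayerNormSignature key(param);
  key.AddSign(data);          // mix the input's shape/dtype into the key
  auto it = cache.find(key);  // equal hashes are resolved with LayerNormParam::operator==
  if (it == cache.end()) {
    MKLDNNLayerNormFwd fwd(param, data);  // build the oneDNN primitive once
    it = cache.emplace(key, fwd).first;
  }
  return it->second;
}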
66 changes: 66 additions & 0 deletions src/operator/nn/layer_norm.cc
@@ -57,6 +57,10 @@
#include "layer_norm-inl.h"
#include <nnvm/op_attr_types.h>
#include "../elemwise_op_common.h"
#if MXNET_USE_ONEDNN == 1
#include "./mkldnn/mkldnn_base-inl.h"
#include "./mkldnn/mkldnn_ops-inl.h"
#endif // MXNET_USE_ONEDNN

#if MSHADOW_USE_MKL == 1
#include "../mkl_functions-inl.h"
@@ -392,6 +396,50 @@ void LayerNormGradCompute<cpu>(const nnvm::NodeAttrs& attrs,
return LayerNormGradComputeGeneral<cpu>(attrs, ctx, inputs, req, outputs);
}

#if MXNET_USE_ONEDNN == 1
static bool LayerNormInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK(!in_attrs->empty());

return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}

static void LayerNormComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
if (SupportMKLDNNLayerNorm(param, inputs)) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNRun(MKLDNNLayerNormForward, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(LayerNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
} else {
FallBackCompute(LayerNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
}

static void LayerNormGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
if (SupportMKLDNNLayerNorm(param, inputs)) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNRun(MKLDNNLayerNormBackward, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(LayerNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
} else {
FallBackCompute(LayerNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
}
#endif

NNVM_REGISTER_OP(LayerNorm)
.add_alias("_npx_layer_norm")
.describe(R"code(Layer normalization.
Expand Down Expand Up @@ -439,6 +487,11 @@ axis to be the last item in the input shape.
.set_attr<mxnet::FInferShape>("FInferShape", LayerNormShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 3>)
.set_attr<FCompute>("FCompute<cpu>", LayerNormCompute<cpu>)
#if MXNET_USE_ONEDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FInferStorageType>("FInferStorageType", LayerNormInferStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", LayerNormComputeExCPU)
#endif
.set_attr<nnvm::FGradient>("FGradient", [](const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> heads;
@@ -447,6 +500,10 @@ axis to be the last item in the input shape.
heads.push_back(n->inputs[1]); // gamma
heads.emplace_back(n, 1, 0); // mean
heads.emplace_back(n, 2, 0); // std
#if MXNET_USE_ONEDNN == 1
heads.push_back(n->inputs[2]); // beta - needed for MKLDNN backward propagation;
// added at the end in case of fallback to the non-MKLDNN version
#endif
return MakeGradNode("_backward_LayerNorm", n, heads, n->attrs.dict);
})
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
@@ -464,11 +521,20 @@ axis to be the last item in the input shape.


NNVM_REGISTER_OP(_backward_LayerNorm)
#if MXNET_USE_ONEDNN == 1
.set_num_inputs(6)
#else
.set_num_inputs(5)
#endif
.set_num_outputs(3)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<LayerNormParam>)
.set_attr<FCompute>("FCompute<cpu>", LayerNormGradCompute<cpu>)
#if MXNET_USE_ONEDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", LayerNormInferStorageType)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", LayerNormGradComputeExCPU)
#endif
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
});
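
Both FComputeEx entry points above gate on SupportMKLDNNLayerNorm, whose declaration is added to mkldnn_base-inl.h below but whose body (in mkldnn_layer_norm.cc) is not part of this excerpt. The sketch below only illustrates the kind of conditions such a check typically enforces; the concrete rank limits and dtype list are assumptions, not taken from the PR.

// Assumption-labelled sketch - not the PR's actual implementation.
// oneDNN's layer_normalization primitive normalizes over the last axis, so a
// check along these lines rejects anything the generic CPU kernel must handle.
bool SupportMKLDNNLayerNormSketch(const LayerNormParam& param,
                                  const std::vector<NDArray>& inputs) {
  const NDArray& data        = inputs[layernorm::kData];
  const mxnet::TShape& shape = data.shape();
  if (shape.ndim() < 2 || shape.ndim() > 5)  // hypothetical rank limits
    return false;
  if (GetRealAxis(param.axis, shape.ndim()) != shape.ndim() - 1)
    return false;  // only last-axis normalization maps onto the oneDNN primitive
  const int dtype = data.dtype();
  return dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16;  // assumed dtype set
}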
4 changes: 4 additions & 0 deletions src/operator/nn/layer_norm.cu
@@ -689,7 +689,11 @@ void LayerNormGradGPUContig(const LayerNormParam param,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
#if MXNET_USE_ONEDNN == 1
CHECK_EQ(inputs.size(), 6U); // additional beta tensor
#else
CHECK_EQ(inputs.size(), 5U);
#endif
const TBlob out_grad = inputs[0];
const TBlob in_data = inputs[1];
const TBlob gamma = inputs[2];
2 changes: 2 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -202,6 +202,7 @@ struct SoftmaxParam;
struct SoftmaxOutputParam;
struct TransposeParam;
struct ReshapeParam;
struct LayerNormParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNAct(const ActivationParam& param, const NDArray& input);
bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param);
@@ -216,6 +217,7 @@ bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param,
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam& param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray& data);
bool SupportMKLDNNBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
bool SupportMKLDNNLayerNorm(const LayerNormParam& param, const std::vector<NDArray> &inputs);
} // namespace op

static int GetTypeSize(int dtype) {
1 change: 1 addition & 0 deletions src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h
@@ -19,6 +19,7 @@

/*!
* \file mkldnn_batch_dot-inl.h
* \author: Bartosz Kuncer, bartosz.kuncer@intel.com
*/

#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
1 change: 1 addition & 0 deletions src/operator/nn/mkldnn/mkldnn_batch_dot.cc
@@ -19,6 +19,7 @@

/*!
* \file mkldnn_batch_dot.cc
* \author: Bartosz Kuncer, bartosz.kuncer@intel.com
*/

#if MXNET_USE_ONEDNN == 1
103 changes: 103 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_layer_norm-inl.h
* \author: Bartosz Kuncer, bartosz.kuncer@intel.com
*/
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_

#if MXNET_USE_ONEDNN == 1

#include <utility>
#include <vector>

#include "../layer_norm-inl.h"
#include "./mkldnn_base-inl.h"
#include "./mkldnn_ops-inl.h"

namespace mxnet {
namespace op {

using layernorm_fwd_t = mkldnn::layer_normalization_forward;
using layernorm_fwd_pd_t = mkldnn::layer_normalization_forward::primitive_desc;

using layernorm_bwd_t = mkldnn::layer_normalization_backward;
using layernorm_bwd_pd_t = mkldnn::layer_normalization_backward::primitive_desc;

typedef ParamOpSign<LayerNormParam> LayerNormSignature;

class MKLDNNLayerNormFwd {
public:
static MKLDNNLayerNormFwd& GetCached(const LayerNormParam& param,
const OpContext& ctx,
const NDArray& data);

MKLDNNLayerNormFwd(const LayerNormParam& param, const NDArray& data);

static std::shared_ptr<layernorm_fwd_pd_t> CreatePrimitiveDesc(
const LayerNormParam& param,
const mkldnn::memory::desc& src_md);

void Execute(const LayerNormParam& param,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const OpReqType& req,
const std::vector<NDArray>& outputs) const;

~MKLDNNLayerNormFwd() {}

private:
std::shared_ptr<layernorm_fwd_t> fwd;
std::shared_ptr<layernorm_fwd_pd_t> fwd_pd;
};

class MKLDNNLayerNormBwd {
public:
static MKLDNNLayerNormBwd& GetCached(const LayerNormParam& param,
const std::vector<NDArray>& inputs);

MKLDNNLayerNormBwd(const LayerNormParam& param,
const std::vector<NDArray>& inputs,
const mkldnn::memory::desc& data_md,
const mkldnn::memory::desc& diff_md);

static std::shared_ptr<layernorm_bwd_pd_t> CreatePrimitiveDesc(
const LayerNormParam& param,
const mkldnn::memory::desc& data_md,
const mkldnn::memory::desc& diff_md,
const layernorm_fwd_pd_t& layernorm_fwd_pd);

void Execute(const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs,
const std::vector<OpReqType>& req) const;

~MKLDNNLayerNormBwd() {}

private:
std::shared_ptr<layernorm_bwd_t> bwd;
std::shared_ptr<layernorm_fwd_pd_t> fwd_pd;
std::shared_ptr<layernorm_bwd_pd_t> bwd_pd;
};

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_ONEDNN == 1
#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_
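
The header above only declares MKLDNNLayerNormFwd::CreatePrimitiveDesc; its definition lives in mkldnn_layer_norm.cc, which is not part of this excerpt. As a rough idea of what such a primitive_desc factory sets up, here is a sketch assuming the oneDNN 2.x C++ API with the use_scale_shift flag; the names and details are illustrative, not the PR's code.

// Sketch under stated assumptions (oneDNN 2.x API), not the actual definition.
// forward_training is required because kMean/kStd are operator outputs, and
// use_scale_shift tells oneDNN that gamma/beta are supplied.
std::shared_ptr<layernorm_fwd_pd_t> CreatePrimitiveDescSketch(
    const LayerNormParam& param, const mkldnn::memory::desc& src_md) {
  mkldnn::layer_normalization_forward::desc fwd_desc(
      mkldnn::prop_kind::forward_training, src_md,
      param.eps, mkldnn::normalization_flags::use_scale_shift);
  const mkldnn::engine& cpu_engine = CpuEngine::Get()->get_engine();
  return std::make_shared<layernorm_fwd_pd_t>(fwd_desc, cpu_engine);
}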