From 797d1f39cef43322f596c4a15a5f56521f34b55e Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Sun, 26 Sep 2021 21:32:52 -0700
Subject: [PATCH 1/4] Fast cuDNN NHWC kernels support

---
 src/operator/nn/batch_norm.cc                |   2 +-
 src/operator/nn/batch_norm.cu                |  32 +-
 src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 307 -------------------
 src/operator/nn/cudnn/cudnn_batch_norm.cc    | 258 ++++++++++-------
 src/operator/nn/cudnn/cudnn_batch_norm.h     |  56 ++++
 5 files changed, 236 insertions(+), 419 deletions(-)
 delete mode 100644 src/operator/nn/cudnn/cudnn_batch_norm-inl.h
 create mode 100644 src/operator/nn/cudnn/cudnn_batch_norm.h

diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index fb12180282ee..5a18363abe39 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -649,11 +649,11 @@ then set ``gamma`` to 1 and its gradient to 0.
     .set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
 #if MXNET_USE_ONEDNN == 1
     .set_attr<bool>("TIsMKLDNN", true)
+#endif
     .set_attr<FResourceRequest>("FResourceRequest",
                                 [](const NodeAttrs& n) {
                                   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
                                 })
-#endif
     .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
     .add_argument("gamma", "NDArray-or-Symbol", "gamma array")
     .add_argument("beta", "NDArray-or-Symbol", "beta array")
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 7807691e7c2b..195423bd1419 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -39,7 +39,7 @@
 #define ADDTO_BETA_FLAG (1 << 8)

 #if MXNET_USE_CUDNN == 1
-#include "./cudnn/cudnn_batch_norm-inl.h"
+#include "./cudnn/cudnn_batch_norm.h"
 #endif

 #include "../../../include/mxnet/tensor_blob.h"
@@ -935,11 +935,6 @@ static void BatchNormalizationBackward(mshadow::Stream<gpu>* s,
       (flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0;

   if (is_train_and_not_global_stats) {
-#ifdef NDEBUG
-    constexpr bool SMALLER_THREADS = false;
-#else
-    constexpr bool SMALLER_THREADS = true;
-#endif
     dim3 blocks(gradOutput.ChannelCount());
     dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize()));
     BatchNormalizationBackwardKernel<DType, AccReal, DeviceTensor, DeviceTensor1>
@@ -1104,19 +1099,6 @@ void BatchNormBackwardImpl(mshadow::Stream<gpu>* stream,
   MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoBackward_gpu);
 }

-#if MXNET_USE_CUDNN == 1
-template <typename DType>
-static CuDNNBatchNormOp<DType>& GetCuDNNOp(const BatchNormParam& param) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local CuDNNBatchNormOp<DType> op;
-#else
-  static MX_THREAD_LOCAL CuDNNBatchNormOp<DType> op;
-#endif
-  op.Init(param);
-  return op;
-}
-#endif
-
 template <>
 void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
@@ -1132,9 +1114,9 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);

 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off) {
-    MSHADOW_REAL_TYPE_SWITCH(
-        dtype, DType, { GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states); })
+  if (!param.use_global_stats && !param.cudnn_off &&
+      CudnnBatchNormSupports(param, inputs[batchnorm::kData])) {
+    CudnnBatchNormForward(param, ctx, inputs, req, outputs);
   } else {
     MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
       BatchNormForward<gpu, DType, AccReal>(ctx, param, in_data, req, outputs, aux_states);
@@ -1160,9 +1142,9 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);

 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off) {
-    MSHADOW_REAL_TYPE_SWITCH(
-        dtype, DType, { GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs); })
+  if (!param.use_global_stats && !param.cudnn_off &&
+      CudnnBatchNormSupports(param, inputs[3 + batchnorm::kData])) {
+    CudnnBatchNormBackward(param, ctx, inputs, req, outputs);
   } else {
     MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
       BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
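The backward dispatch above locates the forward data blob at inputs[3 + batchnorm::kData]. For orientation, a minimal sketch of the backward input layout that arithmetic assumes; the enum names here are illustrative, only the indices come from the code:

// Backward receives 8 inputs (see CHECK_EQ(inputs.size(), 8) below): three
// gradient-path blobs followed by the original forward inputs, so the forward
// blobs are found at offset 3 plus their batchnorm::k* index.
enum BackwardInputs {
  kOutGrad    = 0,  // dL/dy
  kSavedMean  = 1,  // mean saved by the forward pass
  kSavedVar   = 2,  // inverse variance saved by the forward pass
  kFwdData    = 3,  // == 3 + batchnorm::kData  (kData  == 0)
  kFwdGamma   = 4,  // == 3 + batchnorm::kGamma (kGamma == 1)
  kFwdBeta    = 5,  // == 3 + batchnorm::kBeta  (kBeta  == 2)
  kMovingMean = 6,  // mutated in place (FMutateInputs {6, 7})
  kMovingVar  = 7
};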
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
deleted file mode 100644
index 0f79430cfbeb..000000000000
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ /dev/null
@@ -1,307 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file cudnn_batch_norm-inl.h
- * \brief
- * \author Junyuan Xie
- */
-
-#ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_INL_H_
-#define MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_INL_H_
-#include <vector>
-#include <map>
-#include <string>
-#include <utility>
-#include "../batch_norm-inl.h"
-
-namespace mxnet {
-namespace op {
-#if MXNET_USE_CUDNN == 1
-namespace cudnnbatchnorm {
-enum CuDNNBatchNormOpInputs { kData, kGamma, kBeta };
-enum CuDNNBatchNormOpOutputs { kOut, kMean, kInvVar };
-enum CuDNNBatchNormOpAuxiliary { kMovingMean, kMovingInvVar };
-}  // namespace cudnnbatchnorm
-
-#if defined(__CUDACC__)
-template <typename DType>
-class CuDNNBatchNormOp {
-  STATIC_ASSERT_CUDNN_VERSION_GE(5000);
-
- public:
-  CuDNNBatchNormOp() {
-    using namespace mshadow;
-    dtype_ = DataType<DType>::kCudnnFlag;
-    // For float16 input type beta, gamma, mean, and average are stored in float32.
-    // For other input types, these parameters have the same type as input
-    dtype_param_ = (dtype_ == CUDNN_DATA_HALF) ? kFloat32 : DataType<DType>::kFlag;
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc_));
-    internal_aux_states_lock_ = false;
-  }
-
-  void Init(const BatchNormParam& param) {
-    CHECK_GE(param.eps, CUDNN_BN_MIN_EPSILON)
-        << "CuDNN requires eps to be no less than " << CUDNN_BN_MIN_EPSILON;
-    this->param_ = param;
-  }
-
-  ~CuDNNBatchNormOp() {
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(io_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(mean_desc_));
-  }
-
-  void Forward(const OpContext& ctx,
-               const std::vector<TBlob>& in_data,
-               const std::vector<OpReqType>& req,
-               const std::vector<TBlob>& out_data,
-               const std::vector<TBlob>& aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 3U);
-    CHECK_EQ(aux_states.size(), 2U);
-    if (ctx.is_train) {
-      CHECK_EQ(out_data.size(), 3U);
-      CHECK_EQ(req.size(), 3U);
-    } else {
-      CHECK_GE(out_data.size(), 1U);
-      CHECK_GE(req.size(), 1U);
-    }
-    CHECK_EQ(req[cudnnbatchnorm::kOut], kWriteTo);
-    CHECK_GE(in_data[cudnnbatchnorm::kData].ndim(), 2);
-
-    Init(in_data[cudnnbatchnorm::kData]);
-    Stream<gpu>* s = ctx.get_stream<gpu>();
-    Tensor<gpu, 4, DType> x =
-        in_data[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
-
-    Tensor<gpu, 4, DType> y =
-        out_data[cudnnbatchnorm::kOut].get_with_shape<gpu, 4, DType>(shape_, s);
-#if CUDNN_VERSION >= 7002
-    auto mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
-#else
-    auto mode = CUDNN_BATCHNORM_SPATIAL;
-#endif
-
-    MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
-      Tensor<gpu, 1, DTypeParam> gamma =
-          in_data[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> beta =
-          in_data[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> moving_mean =
-          aux_states[cudnnbatchnorm::kMovingMean].get_with_shape<gpu, 1, DTypeParam>(
-              Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> moving_inv_var =
-          aux_states[cudnnbatchnorm::kMovingInvVar].get_with_shape<gpu, 1, DTypeParam>(
-              Shape1(shape_[1]), s);
-      typename DataType<DType>::ScaleType a = 1.0f;
-      typename DataType<DType>::ScaleType b = 0.0f;
-
-      if (param_.fix_gamma)
-        gamma = 1.f;
-
-      if (ctx.is_train) {
-        Tensor<gpu, 1, DTypeParam> save_mean =
-            out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]),
-                                                                               s);
-        Tensor<gpu, 1, DTypeParam> save_inv_var =
-            out_data[cudnnbatchnorm::kInvVar].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]),
-                                                                                 s);
-        // If the lock on the auxiliary states is set, then this implies that
-        // the preceding call is also a `Forward()` call, which further
-        // indicates that we are in the backward mirroring mode, and therefore
-        // update to the auxiliary states is disabled. This is done by setting
-        // the `momentum` to `1` (or `factor` to `0`).
-        float factor =
-            ((dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) || dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) &&
-             internal_aux_states_lock_)
-                ? 0
-                : (1 - param_.momentum);
-        CUDNN_CALL(cudnnBatchNormalizationForwardTraining(s->dnn_handle_,
-                                                          mode,
-                                                          &a,
-                                                          &b,
-                                                          io_desc_,
-                                                          x.dptr_,
-                                                          io_desc_,
-                                                          y.dptr_,
-                                                          mean_desc_,
-                                                          gamma.dptr_,
-                                                          beta.dptr_,
-                                                          factor,
-                                                          moving_mean.dptr_,
-                                                          moving_inv_var.dptr_,
-                                                          param_.eps,
-                                                          save_mean.dptr_,
-                                                          save_inv_var.dptr_));
-      } else {
-        CUDNN_CALL(cudnnBatchNormalizationForwardInference(s->dnn_handle_,
-                                                           CUDNN_BATCHNORM_SPATIAL,
-                                                           &a,
-                                                           &b,
-                                                           io_desc_,
-                                                           x.dptr_,
-                                                           io_desc_,
-                                                           y.dptr_,
-                                                           mean_desc_,
-                                                           gamma.dptr_,
-                                                           beta.dptr_,
-                                                           moving_mean.dptr_,
-                                                           moving_inv_var.dptr_,
-                                                           param_.eps));
-      }
-    })
-    // Set the lock on the auxiliary states.
-    // If the next call to the operator is a `Forward()` call,
-    // then `momentum` will be set to `1` and hence auxiliary states will not be updated.
-    internal_aux_states_lock_ = true;
-  }
-
-  void Backward(const OpContext& ctx,
-                const std::vector<TBlob>& inputs,
-                const std::vector<OpReqType>& req,
-                const std::vector<TBlob>& outputs) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(inputs.size(), 8U);
-    CHECK_EQ(outputs.size(), 3U);
-
-    // Rename the inputs and outputs.
-    const TBlob& out_grad = inputs[0];
-    const TBlob& out_mean = inputs[1];
-    const TBlob& out_var = inputs[2];
-    const TBlob& in_data = inputs[3];
-    const TBlob& in_gamma = inputs[4];
-    const std::vector<TBlob>& in_grad = outputs;
-
-    Init(in_data);
-    Stream<gpu>* s = ctx.get_stream<gpu>();
-    Tensor<gpu, 4, DType> x = in_data.get_with_shape<gpu, 4, DType>(shape_, s);
-    Tensor<gpu, 4, DType> dx =
-        in_grad[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
-    Tensor<gpu, 4, DType> dy = out_grad.get_with_shape<gpu, 4, DType>(shape_, s);
-
-    const bool global_stats = !ctx.is_train || param_.use_global_stats;
-
-#if CUDNN_VERSION >= 7002
-    auto mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
-#else
-    auto mode = CUDNN_BATCHNORM_SPATIAL;
-#endif
-    MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
-      Tensor<gpu, 1, DTypeParam> gamma =
-          in_gamma.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> dbeta =
-          in_grad[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> dgamma =
-          in_grad[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> save_mean =
-          out_mean.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> save_inv_var =
-          out_var.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-
-      typename DataType<DType>::ScaleType a = 1.0f;
-      typename DataType<DType>::ScaleType b = 0.0f;
-      typename DataType<DType>::ScaleType b_add = 1.0f;
-      CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
-
-      if (param_.fix_gamma)
-        gamma = 1.f;
-
-      bool grad_add_gamma_beta =
-          (req[cudnnbatchnorm::kGamma] == kAddTo) || (req[cudnnbatchnorm::kBeta] == kAddTo);
-      if (grad_add_gamma_beta) {
-        if (IsBNWriting(req[cudnnbatchnorm::kGamma])) {
-          dgamma = 0.f;
-        }
-        if (IsBNWriting(req[cudnnbatchnorm::kBeta])) {
-          dbeta = 0.f;
-        }
-      }
-
-      CUDNN_CALL(
-          cudnnBatchNormalizationBackward(s->dnn_handle_,
-                                          mode,
-                                          &a,
-                                          req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b,
-                                          &a,
-                                          grad_add_gamma_beta ? &b_add : &b,  // gamma and beta
-                                          io_desc_,
-                                          x.dptr_,
-                                          io_desc_,
-                                          dy.dptr_,
-                                          io_desc_,
-                                          dx.dptr_,
-                                          mean_desc_,
-                                          gamma.dptr_,
-                                          dgamma.dptr_,
-                                          dbeta.dptr_,
-                                          param_.eps,
-                                          global_stats ? nullptr : save_mean.dptr_,
-                                          global_stats ? nullptr : save_inv_var.dptr_));
-      if (param_.fix_gamma)
-        dgamma = 0.f;
-    })
-    // Release the lock on the auxiliary states, so that the next forward pass
-    // will be able to update the auxiliary states normally.
-    internal_aux_states_lock_ = false;
-  }
-
- private:
-  void Init(const TBlob& in_data) {
-    CHECK_GE(param_.axis, 0);
-    CHECK_LT(param_.axis, in_data.ndim());
-    if (param_.axis == 1) {
-      if (in_data.ndim() == 4) {
-        for (int i = 0; i < 4; ++i)
-          shape_[i] = in_data.shape_[i];
-      } else {
-        // when in_data.ndim() != 4
-        shape_[0] = in_data.shape_[0];
-        shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
-        shape_[2] = 1;
-        shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(2, in_data.ndim()));
-      }
-    } else {
-      // reshape to (N, C, 1, D), C is the `param_.axis` dimension
-      shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
-      shape_[1] = in_data.shape_[param_.axis];
-      shape_[2] = 1;
-      shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1, in_data.ndim()));
-    }
-
-    CUDNN_CALL(cudnnSetTensor4dDescriptor(
-        io_desc_, CUDNN_TENSOR_NCHW, dtype_, shape_[0], shape_[1], shape_[2], shape_[3]));
-    CUDNN_CALL(cudnnDeriveBNTensorDescriptor(mean_desc_, io_desc_, CUDNN_BATCHNORM_SPATIAL));
-  }
-
-  cudnnDataType_t dtype_;
-  int dtype_param_;
-  cudnnTensorDescriptor_t io_desc_, mean_desc_;
-  mshadow::Shape<4> shape_;
-  BatchNormParam param_;
-  bool internal_aux_states_lock_;
-};
-#endif  // defined(__CUDACC__)
-
-#endif  // MXNET_USE_CUDNN == 1
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_INL_H_
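The rewritten cudnn_batch_norm.cc below folds an arbitrary N-D input into the 4-D (n, c, h, w) descriptor that cuDNN expects, keeping only the innermost spatial dimension as w and multiplying the remaining spatial dimensions into h. A standalone sketch of that folding, with a worked example in the comments; the Nchw/FoldShape names are illustrative, only the index arithmetic mirrors SetDescriptors:

#include <vector>

struct Nchw { int n, c, h, w; };

// Mirror of SetDescriptors' shape folding. For an NHWC 5-D input
// [N=2, D=3, H=4, W=5, C=8] with axis == ndim - 1:
//   last_spatial_i = ndim - 2 = 3, so w = 5, h = D * H = 12,
//   giving a (n=2, c=8, h=12, w=5) NHWC descriptor.
// For NCHW (axis == 1) the innermost dim is shape[ndim - 1] instead.
Nchw FoldShape(const std::vector<int>& shape, int axis) {
  const int ndim = static_cast<int>(shape.size());  // requires ndim >= 3
  const int last_spatial_i = (axis == 1) ? ndim - 1 : ndim - 2;
  int h = 1;
  for (int i = last_spatial_i - (ndim - 3); i < last_spatial_i; ++i)
    h *= shape[i];  // ProdShape(last_spatial_i - (ndim - 3), last_spatial_i)
  return {shape[0], shape[axis], h, shape[last_spatial_i]};
}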
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.cc b/src/operator/nn/cudnn/cudnn_batch_norm.cc
index 5ea46f285572..9df046f7d4bb 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm.cc
+++ b/src/operator/nn/cudnn/cudnn_batch_norm.cc
@@ -23,103 +23,189 @@
  * \author Junyuan Xie, Da Zheng
  */

-#include "./cudnn_batch_norm-inl.h"
-#include <nnvm/op_attr_types.h>
-#include "../../elemwise_op_common.h"
+#include "cudnn_batch_norm.h"
+
+#include "../../../common/cuda/utils.h"

 namespace mxnet {
 namespace op {
+
 #if MXNET_USE_CUDNN == 1
-static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
-                           mxnet::ShapeVector* in_shape,
-                           mxnet::ShapeVector* out_shape) {
-  using namespace mshadow;
-  CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, moving_mean, moving_var]";
-  const mxnet::TShape& dshape = in_shape->at(0);
-  if (!mxnet::ndim_is_known(dshape))
-    return false;
-  in_shape->at(1) = mxnet::TShape(Shape1(dshape[1]));
-  in_shape->at(2) = mxnet::TShape(Shape1(dshape[1]));
-  in_shape->at(3) = mxnet::TShape(Shape1(dshape[1]));
-  in_shape->at(4) = mxnet::TShape(Shape1(dshape[1]));
-
-  out_shape->clear();
-  out_shape->push_back(dshape);
-  out_shape->push_back(Shape1(dshape[1]));
-  out_shape->push_back(Shape1(dshape[1]));
-
-  return true;
+namespace {
+
+struct Globals {
+  cudnnTensorDescriptor_t io_desc;
+  cudnnTensorDescriptor_t mean_desc;
+  bool internal_aux_states_lock = false;
+
+  static Globals& Get() {
+    thread_local Globals ret;
+    return ret;
+  }
+
+  Globals() {
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc));
+  }
+
+  ~Globals() {
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(io_desc));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(mean_desc));
+  }
+};
+
+void SetDescriptors(const BatchNormParam& param, const TBlob& x) {
+  CHECK_GE(x.shape_.ndim(), 3);
+  CHECK(param.axis == 1 || param.axis == x.shape_.ndim() - 1);
+
+  cudnnTensorFormat_t format = param.axis == 1 ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+  int n = x.shape_[0];
+  int c = x.shape_[param.axis];
+  size_t last_spatial_i = param.axis == 1 ? x.shape_.ndim() - 1 : x.shape_.ndim() - 2;
+  int w = x.shape_[last_spatial_i];
+  int h = x.shape_.ProdShape(last_spatial_i - (x.shape_.ndim() - 3), last_spatial_i);
+
+  MSHADOW_REAL_TYPE_SWITCH(x.type_flag_, DType, {
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(Globals::Get().io_desc, format,
+                                          mshadow::DataType<DType>::kCudnnFlag, n, c, h, w));
+  })
+  CUDNN_CALL(cudnnDeriveBNTensorDescriptor(Globals::Get().mean_desc, Globals::Get().io_desc,
+                                           CUDNN_BATCHNORM_SPATIAL));
 }

-static void BatchNormCompute_CPU(const nnvm::NodeAttrs& attrs,
-                                 const OpContext& ctx,
-                                 const std::vector<TBlob>& inputs,
-                                 const std::vector<OpReqType>& req,
-                                 const std::vector<TBlob>& outputs) {
-  LOG(FATAL) << "CuDNNBatchNormOp is only available for gpu.";
+mshadow::TypeFlag ParamType(int x_type) {
+  auto xt = static_cast<mshadow::TypeFlag>(x_type);
+  return xt == mshadow::kFloat16 ? mshadow::kFloat32 : xt;
 }

-static void BatchNormGradCompute_CPU(const nnvm::NodeAttrs& attrs,
-                                     const OpContext& ctx,
-                                     const std::vector<TBlob>& inputs,
-                                     const std::vector<OpReqType>& req,
-                                     const std::vector<TBlob>& outputs) {
-  LOG(FATAL) << "CuDNNBatchNormOp is only available for gpu.";
-}
+}  // namespace
+
+bool CudnnBatchNormSupports(const BatchNormParam& param, const TBlob& x) {
+  int n = x.shape_.ndim();
+  return n >= 3 && (param.axis == 1 || param.axis == n - 1);
+}
+
+void CudnnBatchNormForward(const BatchNormParam& param, const OpContext& ctx,
+                           const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 5);
+  if (ctx.is_train) {
+    CHECK_EQ(outputs.size(), 3);
+    CHECK_EQ(req.size(), 3);
+  } else {
+    CHECK_GE(outputs.size(), 1);
+    CHECK_GE(req.size(), 1);
+  }
+  CHECK_EQ(req[batchnorm::kOut], kWriteTo);
+  CHECK_GE(inputs[batchnorm::kData].ndim(), 2);
+
+  SetDescriptors(param, inputs[batchnorm::kData]);
+
+  auto s = ctx.get_stream<gpu>();
+  MSHADOW_REAL_TYPE_SWITCH(ParamType(inputs[batchnorm::kData].type_flag_), DType, {
+    DType a = 1.0f;
+    DType b = 0.0f;
+    if (param.fix_gamma)
+      inputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 1.0f;
+    if (ctx.is_train) {
+      size_t workspace_size = 0;
+      CUDNN_CALL(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
+          s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
+          Globals::Get().io_desc, nullptr, Globals::Get().io_desc, Globals::Get().mean_desc,
+          nullptr, &workspace_size));
+      auto workspace =
+          ctx.requested[0].get_space_internal(workspace_size, "CudnnBatchNormForward");
+
+      // If the lock on the auxiliary states is set, then this implies that
+      // the preceding call is also a `Forward()` call, which further
+      // indicates that we are in the backward mirroring mode, and therefore
+      // update to the auxiliary states is disabled. This is done by setting
+      // the `momentum` to `1` (or `factor` to `0`).
+      double factor =
+          ((dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) || dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) &&
+           Globals::Get().internal_aux_states_lock)
+              ? 0
+              : (1 - param.momentum);
+      CUDNN_CALL(cudnnBatchNormalizationForwardTrainingEx(
+          s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, &a, &b,
+          Globals::Get().io_desc, inputs[batchnorm::kData].dptr_,
+          nullptr, nullptr,  // zDesc, zData
+          Globals::Get().io_desc, outputs[batchnorm::kOut].dptr_,
+          Globals::Get().mean_desc,
+          inputs[batchnorm::kGamma].dptr_, inputs[batchnorm::kBeta].dptr_,
+          factor, inputs[batchnorm::kInMovingMean].dptr_, inputs[batchnorm::kInMovingVar].dptr_,
+          param.eps, outputs[batchnorm::kMean].dptr_, outputs[batchnorm::kVar].dptr_,
+          nullptr,  // activation desc
+          workspace, workspace_size,
+          nullptr, 0  // reserveSpace, reserveSpaceSizeInBytes
+          ));
+    } else {
+      CUDNN_CALL(cudnnBatchNormalizationForwardInference(
+          s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL, &a, &b,
+          Globals::Get().io_desc, inputs[batchnorm::kData].dptr_,
+          Globals::Get().io_desc, outputs[batchnorm::kOut].dptr_,
+          Globals::Get().mean_desc,
+          inputs[batchnorm::kGamma].dptr_, inputs[batchnorm::kBeta].dptr_,
+          inputs[batchnorm::kInMovingMean].dptr_, inputs[batchnorm::kInMovingVar].dptr_,
+          param.eps));
+    }
+  })
+  // Set the lock on the auxiliary states.
+  // If the next call to the operator is a `Forward()` call,
+  // then `momentum` will be set to `1` and hence auxiliary states will not be updated.
+  Globals::Get().internal_aux_states_lock = true;
+}

-NNVM_REGISTER_OP(CuDNNBatchNorm)
-    .describe("Apply batch normalization to input.")
-    .set_num_inputs(5)
-    .set_num_outputs(3)
-    .set_attr_parser(ParamParser<BatchNormParam>)
-    .set_attr<nnvm::FListInputNames>(
-        "FListInputNames",
-        [](const NodeAttrs& attrs) {
-          return std::vector<std::string>{"data", "gamma", "beta", "moving_mean", "moving_var"};
-        })
-    .set_attr<nnvm::FListOutputNames>("FListOutputNames",
-                                      [](const NodeAttrs& attrs) {
-                                        return std::vector<std::string>{"output", "mean", "var"};
-                                      })
-    .set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
-                                        [](const NodeAttrs& attrs) { return 1; })
-    .set_attr<nnvm::FMutateInputs>("FMutateInputs",
-                                   [](const nnvm::NodeAttrs& attrs) {
-                                     return std::vector<uint32_t>{3, 4};
-                                   })
-    .set_attr<mxnet::FInferShape>("FInferShape", BatchNormShape)
-    .set_attr<FCompute>("FCompute<cpu>", BatchNormCompute_CPU)
-    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_CuDNNBatchNorm"})
-    .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
-    .add_argument("gamma", "NDArray-or-Symbol", "gamma array")
-    .add_argument("beta", "NDArray-or-Symbol", "beta array")
-    .add_argument("moving_mean", "NDArray-or-Symbol", "running mean of input")
-    .add_argument("moving_var", "NDArray-or-Symbol", "running variance of input")
-    .add_arguments(BatchNormParam::__FIELDS__())
-    .set_attr<nnvm::FSetInputVarAttrOnCompose>(
-        "FSetInputVarAttrOnCompose",
-        [](const nnvm::NodeAttrs& attrs, nnvm::ObjectPtr var, const int index) {
-          if (var->attrs.dict.find("__init__") != var->attrs.dict.end())
-            return;
-          if (index == 3) {
-            var->attrs.dict["__init__"] = "[\"zero\", {}]";
-          } else if (index == 4) {
-            var->attrs.dict["__init__"] = "[\"one\", {}]";
-          }
-        });
-
-NNVM_REGISTER_OP(_backward_CuDNNBatchNorm)
-    .set_num_outputs(5)
-    .set_attr<nnvm::FMutateInputs>("FMutateInputs",
-                                   [](const nnvm::NodeAttrs& attrs) {
-                                     return std::vector<uint32_t>{6, 7};
-                                   })
-    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-    .set_attr_parser(ParamParser<BatchNormParam>)
-    .set_attr<FCompute>("FCompute<cpu>", BatchNormGradCompute_CPU);
-
-#endif  // MXNET_USE_CUDNN
+void CudnnBatchNormBackward(const BatchNormParam& param, const OpContext& ctx,
+                            const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 8);
+  CHECK_EQ(outputs.size(), 3);
+  CHECK_EQ(req.size(), 3);
+
+  SetDescriptors(param, inputs[3 + batchnorm::kData]);
+  auto s = ctx.get_stream<gpu>();
+
+  size_t workspace_size = 0;
+  CUDNN_CALL(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
+      s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
+      Globals::Get().io_desc, Globals::Get().io_desc, Globals::Get().io_desc, nullptr,
+      Globals::Get().io_desc, Globals::Get().mean_desc, nullptr, &workspace_size));
+  auto workspace = ctx.requested[0].get_space_internal(workspace_size, "CudnnBatchNormBackward");
+
+  MSHADOW_REAL_TYPE_SWITCH(ParamType(inputs[3 + batchnorm::kData].type_flag_), DType, {
+    if (param.fix_gamma)
+      inputs[3 + batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 1.0f;
+    bool grad_add_gamma_beta =
+        req[batchnorm::kGamma] == kAddTo || req[batchnorm::kBeta] == kAddTo;
+    if (grad_add_gamma_beta) {
+      if (IsBNWriting(req[batchnorm::kGamma]))
+        outputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 0.0f;
+      if (IsBNWriting(req[batchnorm::kBeta]))
+        outputs[batchnorm::kBeta].FlatTo1D<gpu, DType>(s) = 0.0f;
+    }
+    DType a = 1.0f;
+    DType b = 0.0f;
+    DType b_add = 1.0f;
+    const bool global_stats = !ctx.is_train || param.use_global_stats;
+    CUDNN_CALL(cudnnBatchNormalizationBackwardEx(
+        s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
+        &a, req[batchnorm::kData] == kAddTo ? &b_add : &b,
+        &a, grad_add_gamma_beta ? &b_add : &b,
+        Globals::Get().io_desc, inputs[3 + batchnorm::kData].dptr_,
+        nullptr, nullptr,  // yDesc, yData
+        Globals::Get().io_desc, inputs[batchnorm::kOut].dptr_,
+        nullptr, nullptr,  // dzDesc, dzData
+        Globals::Get().io_desc, outputs[batchnorm::kData].dptr_,
+        Globals::Get().mean_desc,
+        inputs[3 + batchnorm::kGamma].dptr_, inputs[3 + batchnorm::kBeta].dptr_,
+        outputs[batchnorm::kGamma].dptr_, outputs[batchnorm::kBeta].dptr_, param.eps,
+        global_stats ? nullptr : inputs[batchnorm::kMean].dptr_,
+        global_stats ? nullptr : inputs[batchnorm::kVar].dptr_,
+        nullptr,  // activationDesc
+        workspace, workspace_size,
+        nullptr, 0  // reserveSpace, reserveSpaceSizeInBytes
+        ));
+    if (param.fix_gamma)
+      outputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 0.0f;
+  })
+  Globals::Get().internal_aux_states_lock = false;
+}
+#endif  // MXNET_USE_CUDNN == 1

 }  // namespace op
 }  // namespace mxnet
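Both new entry points use the cuDNN `Ex` API, which takes an explicitly sized workspace; this is also why the `FResourceRequest` for `kTempSpace` in batch_norm.cc moved outside the oneDNN guard. A condensed, hedged sketch of the query-then-allocate pattern, assuming the handle and descriptors are already configured (MXNet draws the buffer from its requested temp-space resource rather than calling cudaMalloc per invocation):

#include <cudnn.h>
#include <cuda_runtime.h>

void BnWorkspaceDemo(cudnnHandle_t handle,
                     cudnnTensorDescriptor_t io_desc,
                     cudnnTensorDescriptor_t mean_desc) {
  // 1. Ask cuDNN how much scratch memory the fused training kernel needs.
  size_t workspace_size = 0;
  cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
      handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
      io_desc, /*zDesc=*/nullptr, io_desc, mean_desc,
      /*activationDesc=*/nullptr, &workspace_size);
  // 2. Allocate (illustrative; the patch uses ctx.requested[0], the
  //    kTempSpace resource, instead of a raw cudaMalloc).
  void* workspace = nullptr;
  cudaMalloc(&workspace, workspace_size);
  // 3. Pass (workspace, workspace_size) to
  //    cudnnBatchNormalizationForwardTrainingEx(...), then release.
  cudaFree(workspace);
}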
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.h b/src/operator/nn/cudnn/cudnn_batch_norm.h
new file mode 100644
index 000000000000..23559fc012b2
--- /dev/null
+++ b/src/operator/nn/cudnn/cudnn_batch_norm.h
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
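The STATIC_ASSERT_CUDNN_VERSION_GE(7401) in the new header pins the build to cuDNN 7.4.1, which to our knowledge is the release that introduced the cudnnBatchNormalization*Ex entry points used above. A minimal sketch of an equivalent compile-time guard:

// CUDNN_VERSION (from cudnn.h) encodes major*1000 + minor*100 + patch,
// so 7401 corresponds to cuDNN 7.4.1.
#include <cudnn.h>

static_assert(CUDNN_VERSION >= 7401,
              "cudnnBatchNormalization*Ex requires cuDNN 7.4.1 or newer");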
+ * Copyright (c) 2015 by Contributors
+ * \file cudnn_batch_norm.h
+ * \brief
+ * \author Junyuan Xie
+ */
+
+#ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_
+#define MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_
+
+#include <vector>
+#include <mxnet/base.h>
+#include "../batch_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#if MXNET_USE_CUDNN == 1
+
+STATIC_ASSERT_CUDNN_VERSION_GE(7401);
+
+bool CudnnBatchNormSupports(const BatchNormParam& param, const TBlob& x);
+
+void CudnnBatchNormForward(const BatchNormParam& param, const OpContext& ctx,
+                           const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs);
+
+void CudnnBatchNormBackward(const BatchNormParam& param, const OpContext& ctx,
+                            const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& outputs);
+
+#endif  // MXNET_USE_CUDNN == 1
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_

From bdbfdd44463d0200c3213fedd2e1600000bb7cc9 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Mon, 27 Sep 2021 15:32:36 -0700
Subject: [PATCH 2/4] Fix lint errors

---
 src/operator/nn/cudnn/cudnn_batch_norm.cc | 6 ++----
 src/operator/nn/cudnn/cudnn_batch_norm.h  | 2 +-
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.cc b/src/operator/nn/cudnn/cudnn_batch_norm.cc
index 9df046f7d4bb..300f35c1937b 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm.cc
+++ b/src/operator/nn/cudnn/cudnn_batch_norm.cc
@@ -136,8 +136,7 @@ void CudnnBatchNormForward(const BatchNormParam& param, const OpContext& ctx,
           param.eps, outputs[batchnorm::kMean].dptr_, outputs[batchnorm::kVar].dptr_,
           nullptr,  // activation desc
           workspace, workspace_size,
-          nullptr, 0  // reserveSpace, reserveSpaceSizeInBytes
-          ));
+          nullptr, 0));  // reserveSpace, reserveSpaceSizeInBytes
     } else {
       CUDNN_CALL(cudnnBatchNormalizationForwardInference(
           s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL, &a, &b,
@@ -199,8 +198,7 @@ void CudnnBatchNormBackward(const BatchNormParam& param, const OpContext& ctx,
         global_stats ? nullptr : inputs[batchnorm::kVar].dptr_,
        nullptr,  // activationDesc
        workspace, workspace_size,
-        nullptr, 0  // reserveSpace, reserveSpaceSizeInBytes
-        ));
+        nullptr, 0));  // reserveSpace, reserveSpaceSizeInBytes
     if (param.fix_gamma)
       outputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 0.0f;
   })
   Globals::Get().internal_aux_states_lock = false;
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.h b/src/operator/nn/cudnn/cudnn_batch_norm.h
index 23559fc012b2..57249b184944 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm.h
@@ -27,8 +27,8 @@
 #ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_
 #define MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_

-#include <vector>
 #include <mxnet/base.h>
+#include <vector>
 #include "../batch_norm-inl.h"

 namespace mxnet {

From 44c697fe5386f9ed1b0a048f03b7664955aa156a Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Mon, 27 Sep 2021 17:44:56 -0700
Subject: [PATCH 3/4] Get rid of a warning

---
 .../nn/cudnn/{cudnn_batch_norm.cc => cudnn_batch_norm.cu} | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)
 rename src/operator/nn/cudnn/{cudnn_batch_norm.cc => cudnn_batch_norm.cu} (99%)

diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.cc b/src/operator/nn/cudnn/cudnn_batch_norm.cu
similarity index 99%
rename from src/operator/nn/cudnn/cudnn_batch_norm.cc
rename to src/operator/nn/cudnn/cudnn_batch_norm.cu
index 300f35c1937b..bed274fa4a03 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm.cc
+++ b/src/operator/nn/cudnn/cudnn_batch_norm.cu
@@ -18,7 +18,8 @@
  */

 /*!
- * \file cudnn_batch_norm.cc
+ * Copyright (c) 2015 by Contributors
+ * \file cudnn_batch_norm.cu
  * \brief
  * \author Junyuan Xie, Da Zheng
  */

From 01f4edf684e6d96f55edb31b0f195821aa8eb0b1 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Mon, 27 Sep 2021 20:33:44 -0700
Subject: [PATCH 4/4] Remove CuDNNBatchNorm from AMP lists

---
 python/mxnet/amp/lists/symbol_fp16.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index d942051c0398..009586ed28f8 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -459,11 +459,6 @@
     'zeros_like',
 ]

-if Features().is_enabled('CUDNN'):
-    FP16_FP32_FUNCS.extend([
-        'CuDNNBatchNorm',
-    ])
-
 # Functions that have to be cast to FP32 due to possible
 # overflows
 FP32_FUNCS = [
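Putting the pieces together, a standalone sketch of the NHWC descriptor setup the new path configures through SetDescriptors; the shapes and flow here are illustrative, not code from the patch:

// Build the NHWC descriptors for a float16 tensor of shape
// (n=32, h=14, w=14, c=256). NHWC + half precision is the combination the
// CUDNN_BATCHNORM_SPATIAL_PERSISTENT kernels are tuned for.
#include <cudnn.h>

int main() {
  cudnnTensorDescriptor_t io_desc, mean_desc;
  cudnnCreateTensorDescriptor(&io_desc);
  cudnnCreateTensorDescriptor(&mean_desc);
  cudnnSetTensor4dDescriptor(io_desc, CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF,
                             32, 256, 14, 14);
  // cuDNN derives the scale/bias/mean/var descriptor as (1, c, 1, 1);
  // for half-precision inputs the derived parameter type is float32.
  cudnnDeriveBNTensorDescriptor(mean_desc, io_desc, CUDNN_BATCHNORM_SPATIAL);
  cudnnDestroyTensorDescriptor(mean_desc);
  cudnnDestroyTensorDescriptor(io_desc);
  return 0;
}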