This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
python/mxnet/amp/lists/symbol_fp16.py (5 changes: 0 additions & 5 deletions)

@@ -459,11 +459,6 @@
 'zeros_like',
 ]
 
-if Features().is_enabled('CUDNN'):
-    FP16_FP32_FUNCS.extend([
-        'CuDNNBatchNorm',
-    ])
-
 # Functions that have to be cast to FP32 due to possible
 # overflows
 FP32_FUNCS = [
src/operator/nn/batch_norm.cc (2 changes: 1 addition & 1 deletion)

@@ -649,11 +649,11 @@ then set ``gamma`` to 1 and its gradient to 0.
 .set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
 #if MXNET_USE_ONEDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
+#endif
 .set_attr<FResourceRequest>("FResourceRequest",
                             [](const NodeAttrs& n) {
                               return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
                             })
-#endif
 .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
 .add_argument("gamma", "NDArray-or-Symbol", "gamma array")
 .add_argument("beta", "NDArray-or-Symbol", "beta array")
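With the `#endif` moved in the hunk above, the `FResourceRequest` attribute is now registered on every build rather than only under `MXNET_USE_ONEDNN`, so the operator always declares its need for temporary workspace. For readers unfamiliar with this mechanism, here is a minimal sketch of how an MXNet compute function typically consumes such a `kTempSpace` request; the function name and size below are illustrative, not code from this PR:

```cpp
#include <mxnet/operator.h>
#include <mshadow/tensor.h>

// Illustrative only: an op whose registration declared
// ResourceRequest::kTempSpace can pull scratch memory from ctx.requested.
template <typename xpu>
void ExampleCompute(const mxnet::OpContext& ctx, size_t workspace_bytes) {
  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
  // The runtime satisfies the request made via FResourceRequest; the kernel
  // retrieves the scratch buffer at the matching index of ctx.requested.
  mshadow::Tensor<xpu, 1, char> workspace =
      ctx.requested[0].get_space_typed<xpu, 1, char>(
          mshadow::Shape1(workspace_bytes), s);
  // ... use workspace.dptr_ as temporary scratch for the kernel ...
}
```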
src/operator/nn/batch_norm.cu (32 changes: 7 additions & 25 deletions)

@@ -39,7 +39,7 @@
 #define ADDTO_BETA_FLAG (1 << 8)
 
 #if MXNET_USE_CUDNN == 1
-#include "./cudnn/cudnn_batch_norm-inl.h"
+#include "./cudnn/cudnn_batch_norm.h"
 #endif
 
 #include "../../../include/mxnet/tensor_blob.h"
@@ -935,11 +935,6 @@ static void BatchNormalizationBackward(mshadow::Stream<gpu>* s,
       (flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0;
 
   if (is_train_and_not_global_stats) {
-#ifdef NDEBUG
-    constexpr bool SMALLER_THREADS = false;
-#else
-    constexpr bool SMALLER_THREADS = true;
-#endif
     dim3 blocks(gradOutput.ChannelCount());
     dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize()));
     BatchNormalizationBackwardKernel<DType, AccReal, DeviceTensor1, batchnorm::BNTensor3<DType>>
@@ -1104,19 +1099,6 @@ void BatchNormBackwardImpl(mshadow::Stream<gpu>* stream,
   MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoBackward_gpu);
 }
 
-#if MXNET_USE_CUDNN == 1
-template <typename DType>
-static CuDNNBatchNormOp<DType>& GetCuDNNOp(const BatchNormParam& param) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local CuDNNBatchNormOp<DType> op;
-#else
-  static MX_THREAD_LOCAL CuDNNBatchNormOp<DType> op;
-#endif
-  op.Init(param);
-  return op;
-}
-#endif
-
 template <>
 void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
@@ -1132,9 +1114,9 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off) {
-    MSHADOW_REAL_TYPE_SWITCH(
-        dtype, DType, { GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states); })
+  if (!param.use_global_stats && !param.cudnn_off &&
+      CudnnBatchNormSupports(param, inputs[batchnorm::kData])) {
+    CudnnBatchNormForward(param, ctx, inputs, req, outputs);
   } else {
     MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
       BatchNormForward<gpu, DType, AccReal>(ctx, param, in_data, req, outputs, aux_states);
@@ -1160,9 +1142,9 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off) {
-    MSHADOW_REAL_TYPE_SWITCH(
-        dtype, DType, { GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs); })
+  if (!param.use_global_stats && !param.cudnn_off &&
+      CudnnBatchNormSupports(param, inputs[3 + batchnorm::kData])) {
+    CudnnBatchNormBackward(param, ctx, inputs, req, outputs);
   } else {
     MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
       BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
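The two call sites above are the only view this diff gives of the interface exposed by the new `cudnn_batch_norm.h`. A sketch of what the header presumably declares; the names come straight from the diff, but the exact signatures are inferred from the calls and may differ from the real header:

```cpp
// Inferred, not copied from cudnn_batch_norm.h: stateless free functions
// replace the per-thread CuDNNBatchNormOp instance that GetCuDNNOp() cached.
// CudnnBatchNormSupports() gates the cuDNN path per input, where the old
// code dispatched to cuDNN unconditionally.
bool CudnnBatchNormSupports(const BatchNormParam& param, const TBlob& data);

void CudnnBatchNormForward(const BatchNormParam& param,
                           const OpContext& ctx,
                           const std::vector<TBlob>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<TBlob>& outputs);

void CudnnBatchNormBackward(const BatchNormParam& param,
                            const OpContext& ctx,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs);
```

Note that the backward gate inspects `inputs[3 + batchnorm::kData]`, presumably because the three output gradients precede the forward inputs in the gradient input layout, putting the data blob three slots in.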
src/operator/nn/cudnn/cudnn_batch_norm-inl.h (307 changes: 0 additions & 307 deletions)

This file was deleted.
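The 307 deleted lines held the `CuDNNBatchNormOp` class that the removed `GetCuDNNOp` helper cached once per thread (via `thread_local`, or `MX_THREAD_LOCAL` as a fallback). Its public surface can be read off the deleted call sites in batch_norm.cu; the sketch below reconstructs only that surface and is not the original code:

```cpp
// Reconstructed from usage in the deleted GetCuDNNOp() and its call sites;
// member functions beyond these are unknown.
template <typename DType>
class CuDNNBatchNormOp {
 public:
  // Re-applied on every call by GetCuDNNOp(), so the cached per-thread
  // instance could be reused with fresh parameters.
  void Init(const BatchNormParam& param);
  void Forward(const OpContext& ctx,
               const std::vector<TBlob>& in_data,
               const std::vector<OpReqType>& req,
               const std::vector<TBlob>& out_data,
               const std::vector<TBlob>& aux_states);
  void Backward(const OpContext& ctx,
                const std::vector<TBlob>& inputs,
                const std::vector<OpReqType>& req,
                const std::vector<TBlob>& outputs);
};
```

Dropping the class removes this mutable per-thread state, which is why the dispatch in BatchNormCompute/BatchNormGradCompute no longer needs a MSHADOW_REAL_TYPE_SWITCH around an op instance.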
