From 41f00cd72a2f4c5436255d017214e203a99893f1 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Mon, 13 May 2019 01:15:29 +0800 Subject: [PATCH] add acceleration fix lint fix lint fix bug further accelerate fix fix bug fix bug --- src/operator/nn/layer_norm-inl.h | 23 +- src/operator/nn/layer_norm.cc | 16 + src/operator/nn/layer_norm.cu | 653 +++++++++++++++++++++++++++++++ 3 files changed, 686 insertions(+), 6 deletions(-) diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h index 3fa2e91681fe..29224243dc40 100644 --- a/src/operator/nn/layer_norm-inl.h +++ b/src/operator/nn/layer_norm-inl.h @@ -63,12 +63,17 @@ struct LayerNormParam : public dmlc::Parameter { } }; - template void LayerNormCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, - const std::vector& outputs) { + const std::vector& outputs); + +template +void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; const LayerNormParam& param = nnvm::get(attrs.parsed); @@ -146,6 +151,12 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs, {kWriteTo}, {outputs[0]}); } +template +void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + /* Calculate the gradient of layer normalization. We have the following gradient for gamma, beta and x: @@ -157,10 +168,10 @@ grad_beta = sum(og, exclude_axis) grad_x = w - mean(w, axis) - \bar{x} * mean(w * \bar{x}, axis) */ template -void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { +void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(inputs.size(), 5U); diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index 2e47503a3318..5b0aca6910f7 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -65,6 +65,22 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs, } +template<> +void LayerNormCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + return LayerNormComputeGeneral(attrs, ctx, inputs, req, outputs); +} + +template<> +void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + return LayerNormGradComputeGeneral(attrs, ctx, inputs, req, outputs); +} + NNVM_REGISTER_OP(LayerNorm) .describe(R"code(Layer normalization. 
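// Illustration (hedged sketch): a minimal single-row reference of what the forward
// path split out above computes, i.e. out = gamma * (x - mean) / std + beta together
// with the stored mean and std = sqrt(var + eps), using the biased variance as the op
// does. Plain C++; the function name layer_norm_row is hypothetical, not MXNet API.
#include <cmath>
#include <vector>

void layer_norm_row(const std::vector<float>& x,
                    const std::vector<float>& gamma,
                    const std::vector<float>& beta,
                    float eps,
                    std::vector<float>* out, float* mean, float* std_eps) {
  const size_t n = x.size();
  float m = 0.f, s2 = 0.f;
  for (float v : x) m += v;
  m /= n;
  for (float v : x) s2 += (v - m) * (v - m);
  s2 /= n;                               // biased (population) variance
  *mean = m;
  *std_eps = std::sqrt(s2 + eps);        // the op stores std, not variance
  out->resize(n);
  for (size_t i = 0; i < n; ++i)
    (*out)[i] = gamma[i] * (x[i] - m) / (*std_eps) + beta[i];
}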
diff --git a/src/operator/nn/layer_norm.cu b/src/operator/nn/layer_norm.cu index a146131294f0..b63046fc0026 100644 --- a/src/operator/nn/layer_norm.cu +++ b/src/operator/nn/layer_norm.cu @@ -24,9 +24,662 @@ */ #include "./layer_norm-inl.h" +using namespace mshadow::cuda; + namespace mxnet { namespace op { +template +__device__ __forceinline__ DType WARP_SHFL(DType value, int src_lane, + int width = 32, unsigned int mask = 0xffffffff) { +#if CUDA_VERSION >= 9000 + return __shfl_sync(mask, value, src_lane, width); +#else + return __shfl(value, src_lane, width); +#endif +} + +template +__device__ __forceinline__ DType WARP_SHFL_XOR(DType value, int laneMask, + int width = 32, unsigned int mask = 0xffffffff) { +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + + +/* A single updating step of the Welford's online algorithm to calculate the mean and variance. + * The value 'curr' will be accumulated to the (mean, sigma2, count) triplet. + * + */ +template +__device__ __forceinline__ void StepWelfordOnlineSum(const DType curr, + DType& mean, //NOLINT + DType& sigma2, //NOLINT + DType& count) { //NOLINT + count += DType(1); + DType delta = curr - mean; + mean += delta / count; + sigma2 += delta * (curr - mean); +} + +/* Merge the mean/variance of two partitions. It's the key step of the Chan's parallel algorithm. + * The (lhs_mean, lhs_sigma2, lhs_count) will be merged into (rhs_mean, rhs_sigma2, rhs_count) + * + * See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance for more details. + * + * TODO(sxjscience) Explore the possibility of int lhs_count and rhs_count + */ +template +__device__ __inline__ void ChanMergePartition(const DType lhs_mean, + const DType lhs_sigma2, + const DType lhs_count, + DType& rhs_mean, //NOLINT + DType& rhs_sigma2, //NOLINT + DType& rhs_count) { //NOLINT + DType delta = rhs_mean - lhs_mean; + DType nA = lhs_count; + DType nB = rhs_count; + rhs_count = nA + nB; + if (rhs_count > DType(0)) { + nA = nA / rhs_count; + nB = nB / rhs_count; + rhs_mean = nA * lhs_mean + nB * rhs_mean; + rhs_sigma2 = rhs_sigma2 + lhs_sigma2 + delta * delta * nA * nB * rhs_count; + } else { + rhs_mean = DType(0); + rhs_sigma2 = DType(0); + } +} + + +template +__device__ __forceinline__ void BlockWelfordOnlineSum(const DType* __restrict__ col_vals, + const int nchannel, + AType& mean, //NOLINT + AType& sigma2, //NOLINT + IType& count) { //NOLINT + int tid = threadIdx.x + threadIdx.y * blockDim.x; + const int nthread = blockDim.x * blockDim.y; + // Each thread takes charge of 4 consecutive numbers. This should optimize the loading speed using + // vectorized types like float4. + // Also, to minimize branch divergence, we split the for-loop into two parts. + int l = 4 * tid; + for (; l + 3 < nchannel; l += 4 * nthread) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + StepWelfordOnlineSum(col_vals[l + i], mean, sigma2, count); + } + } + for (; l < nchannel; ++l) { + StepWelfordOnlineSum(col_vals[l], mean, sigma2, count); + } +} + +/* Fused CUDA kernel for the forward pass of layer normalization. + * It computes the LayerNorm when axis=-1, i.e., contiguous reduction scenario. 
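// Illustration (hedged sketch): a host-side check of the Welford update
// (StepWelfordOnlineSum) and Chan's merge (ChanMergePartition) defined above.
// Accumulating two partitions independently and merging them should match a single
// accumulation over all values. Plain C++; the names Moments, welford_step and
// chan_merge are illustrative, not MXNet API.
#include <cstdio>
#include <vector>

struct Moments { double mean = 0, sigma2 = 0, count = 0; };

void welford_step(double x, Moments* m) {
  m->count += 1;
  double delta = x - m->mean;
  m->mean += delta / m->count;
  m->sigma2 += delta * (x - m->mean);   // running sum of squared deviations
}

void chan_merge(const Moments& lhs, Moments* rhs) {
  double delta = rhs->mean - lhs.mean;
  double nA = lhs.count, nB = rhs->count, n = nA + nB;
  if (n > 0) {
    rhs->count = n;
    rhs->mean = (nA * lhs.mean + nB * rhs->mean) / n;
    rhs->sigma2 += lhs.sigma2 + delta * delta * nA * nB / n;
  }
}

int main() {
  std::vector<double> x = {1, 2, 3, 4, 5, 6, 7, 8};
  Moments a, b, all;
  for (size_t i = 0; i < x.size() / 2; ++i) welford_step(x[i], &a);
  for (size_t i = x.size() / 2; i < x.size(); ++i) welford_step(x[i], &b);
  for (double v : x) welford_step(v, &all);
  chan_merge(a, &b);                    // b now holds the merged statistics
  // Both lines should print mean = 4.5 and variance = 5.25 (biased, i.e. /N)
  std::printf("merged: mean=%g var=%g\n", b.mean, b.sigma2 / b.count);
  std::printf("direct: mean=%g var=%g\n", all.mean, all.sigma2 / all.count);
  return 0;
}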
+ * Shape of the input tensors: + * in_data = (nbatch, nchannel) + * gamma = (nchannel,) + * beta = (nchannel,) + * out_data = (nchannel,) + * mean_data = (nbatch,) + * var_data = (nbatch,) + * It's always launched with (blockDim.x, blockDim.y) = (WARP_SIZE, blockDim.y) + * Also, when blockDim.y > 1, it requires shared memory that has size: + * sizeof(DType) * blockDim.y + sizeof(DType) * blockDim.y / 2 + */ +template +__global__ void LayerNormFusedForwardKernelContig(const int nbatch, + const int nchannel, + const DType eps, + const DType* __restrict__ in_data, + const DType* __restrict__ gamma, + const DType* __restrict__ beta, + DType* __restrict__ out_data, + DType* __restrict__ mean_data, + DType* __restrict__ std_data) { + int bid = blockIdx.x + blockIdx.y * gridDim.x; + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + const int nthread = blockDim.x * blockDim.y; + DType count = 0; + DType mean = 0; + DType sigma2 = 0; + + if (bid < nbatch) { + extern __shared__ char buf[]; // Shared memory + const DType* col_vals = in_data + bid * nchannel; + BlockWelfordOnlineSum(col_vals, nchannel, mean, sigma2, count); + + // Merge the mean/sigma2 within a warp + // Use the Chan's Parallel Algorithm to merge all (mean, sigma2, counts) + // within a warp of threads. + // After calling the function, threadIdx.x == 0 will store the result of + // the aggregated (mean, sigma2, counts). + for (int mask = 16; mask > 0; mask >>= 1) { + DType meanB = WARP_SHFL_XOR(mean, mask); + DType sigma2B = WARP_SHFL_XOR(sigma2, mask); + DType countB = WARP_SHFL_XOR(count, mask); + ChanMergePartition(meanB, sigma2B, countB, mean, sigma2, count); + } + if (blockDim.y > 1) { + // Inter-warp reduction. Copy the upper-half of the warps to shared memory + // and merge with the lower-half warp + DType* mean_buf = reinterpret_cast(buf); + DType* sigma2_buf = reinterpret_cast(buf + sizeof(DType) * blockDim.y / 2 * 32); + DType* count_buf = reinterpret_cast(buf + sizeof(DType) * blockDim.y * 32); + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + mean_buf[idx] = mean; + sigma2_buf[idx] = sigma2; + count_buf[idx] = count; + } + __syncthreads(); + if (threadIdx.y < offset) { + const int idx = threadIdx.y * blockDim.x + threadIdx.x; + ChanMergePartition(mean_buf[idx], sigma2_buf[idx], count_buf[idx], mean, sigma2, count); + } + __syncthreads(); + } + // Broadcast the result to all threads + if (threadIdx.y == 0) { + mean_buf[threadIdx.x] = mean; + sigma2_buf[threadIdx.x] = sigma2; + } + __syncthreads(); + mean = mean_buf[threadIdx.x]; + sigma2 = sigma2_buf[threadIdx.x] / nchannel; + } else { + sigma2 /= nchannel; + } + // Calculate the out_data: gamma * (x - mean) / sqrt(var + eps) + beta + DType std_eps = sqrt(sigma2 + eps); + DType invstd_eps = DType(1.0) / std_eps; + DType* out_col_val = out_data + bid * nchannel; + + if (gamma != NULL && beta != NULL) { + for (int i = tid; i < nchannel; i += nthread) { + out_col_val[i] = gamma[i] * invstd_eps * (col_vals[i] - mean) + beta[i]; + } + } else if (gamma == NULL && beta != NULL) { + for (int i = tid; i < nchannel; i += nthread) { + out_col_val[i] = invstd_eps * (col_vals[i] - mean) + beta[i]; + } + } else if (gamma != NULL && beta == NULL) { + for (int i = tid; i < nchannel; i += nthread) { + out_col_val[i] = gamma[i] * invstd_eps * (col_vals[i] - mean); + } + } else { + for (int i = tid; i < nchannel; i += nthread) { + 
out_col_val[i] = invstd_eps * (col_vals[i] - mean); + } + } + // Write the out_data and var_data + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean_data[bid] = mean; + std_data[bid] = std_eps; + } + } +} + +void LayerNormGPUContig(const LayerNormParam param, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), 3U); + mxnet::TShape data_shape(2, 0); + mxnet::TShape mean_shape(1, 0); + size_t in_ndim = inputs[layernorm::kData].ndim(); + data_shape[0] = mean_shape[0] = inputs[layernorm::kData].shape_.ProdShape(0, in_ndim - 1); + data_shape[1] = inputs[layernorm::kData].shape_[in_ndim - 1]; + const TBlob in_data = inputs[layernorm::kData].reshape(data_shape); + const TBlob gamma = inputs[layernorm::kGamma]; + const TBlob beta = inputs[layernorm::kBeta]; + const TBlob out_data = outputs[layernorm::kOut].reshape(data_shape); + const TBlob mean_data = outputs[layernorm::kMean].reshape(mean_shape); + const TBlob std_data = outputs[layernorm::kStd].reshape(mean_shape); + // Make sure the inputs are contiguous + CHECK_EQ(in_data.CheckContiguous(), true); + CHECK_EQ(gamma.CheckContiguous(), true); + CHECK_EQ(beta.CheckContiguous(), true); + CHECK_EQ(out_data.CheckContiguous(), true); + CHECK_EQ(mean_data.CheckContiguous(), true); + CHECK_EQ(std_data.CheckContiguous(), true); + + // Lauch the kernel. The dynamic shared memory size is + // sizeof(DType) * blockDim.y * blockDim.x + sizeof(DType) * blockDim.y / 2 * blockDim.x + int nbatch = data_shape[0]; + int nchannel = data_shape[1]; + float eps = param.eps; + int ngrid_x = (nbatch > kMaxGridDim) ? (nbatch + kBaseGridNum - 1) / kBaseGridNum : nbatch; + int ngrid_y = (nbatch > kMaxGridDim) ? kBaseGridNum : 1; + int nthread_y; + const dim3 dimGrid(ngrid_x, ngrid_y); + if (nchannel <= 128) { + nthread_y = 1; + } else if (nchannel <= 512) { + nthread_y = 2; + } else { + nthread_y = 4; + } + cudaStream_t stream = Stream::GetStream(ctx.get_stream()); + const dim3 dimBlock(32, nthread_y); + MSHADOW_REAL_TYPE_SWITCH(in_data.type_flag_, DType, { + int nshared = nthread_y > 1 ? nthread_y * 32 * sizeof(DType) + + (nthread_y / 2) * 32 * sizeof(DType) : 0; + CheckLaunchParam(dimGrid, dimBlock); + LayerNormFusedForwardKernelContig<<>> + (nbatch, nchannel, static_cast(eps), + in_data.dptr(), gamma.dptr(), beta.dptr(), + out_data.dptr(), mean_data.dptr(), std_data.dptr()); + }); + MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedForwardKernelContig); +} + +template<> +void LayerNormCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const LayerNormParam& param = nnvm::get(attrs.parsed); + if (req[0] == kNullOp) return; + CHECK_NE(req[0], kAddTo); + int axis = param.axis; + if (axis < 0) { + axis += static_cast(inputs[0].ndim()); + } + CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis; + if (axis == inputs[0].ndim() - 1) { + // Try to use the accelerated CUDA kernels + return LayerNormGPUContig(param, ctx, inputs, req, outputs); + } + return LayerNormComputeGeneral(attrs, ctx, inputs, req, outputs); +} + + +/* Fused CUDA kernel for calculating the gradient w.r.t gamma/beta in LayerNorm when axis=-1 + * (Contiguous case). 
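// Illustration (hedged sketch): the butterfly (XOR) warp reduction pattern that
// LayerNormFusedForwardKernelContig uses via WARP_SHFL_XOR. In the real kernel each
// shuffle exchanges the (mean, sigma2, count) triplet and merges it with
// ChanMergePartition; here a plain sum keeps the toy example short. Assumes
// CUDA >= 9 (__shfl_xor_sync); expected output is 496 (= 0 + 1 + ... + 31).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_allreduce_demo() {
  float v = static_cast<float>(threadIdx.x);          // one value per lane
  for (int mask = 16; mask > 0; mask >>= 1) {
    v += __shfl_xor_sync(0xffffffff, v, mask, 32);    // butterfly exchange
  }
  if (threadIdx.x == 0) {
    printf("warp sum = %.1f\n", v);                   // every lane holds the total
  }
}

int main() {
  warp_allreduce_demo<<<1, 32>>>();
  cudaDeviceSynchronize();
  return 0;
}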
+ * The gradient of gamma and beta are: + * d_gamma = sum(out_grad * (x - mean) / std, axis=0) + * d_beta = sum(out_grad, axis=0) + * + * We compute the gradient (mainly reduction over a non-contiguous axis) using two steps to + * improve the parallelism. + * + * In the first step, we divide the rows uniformly into K parts. K independent threadblocks are used + * to calculate the partial reduction result of each part. Illustrated below: + * + * 1st Block 2nd Block 3rd Block k-th Block + * | --------------- | ---------------- | --------------- | ... | ---------------- | + * | --------------- | ---------------- | --------------- | ... | ---------------- | + * | --------------- | ---------------- | --------------- | ... | ---------------- | + * | --------------- | ---------------- | --------------- | ... | ---------------- | + * part_gamma[0] part_gamma[1] part_gamma[2] part_gamma[k-1] + * part_beta[0] part_beta[1] part_beta[2] part_beta[k-1] + * + * + * In the second step, we sum up the row-values in part_gamma and part_beta. + * + * This `LayerNormFusedBackwardKernel_PartGammaBeta` function implements the first step and + * `LayerNormFusedBackwardKernel_GammaBeta` implements the second step. + */ +template +__global__ void LayerNormFusedBackwardKernel_PartGammaBeta(const int nbatch, + const int nchannel, + const DType* __restrict__ in_data, + const DType* __restrict__ out_grad, + const DType* __restrict__ mean_data, + const DType* __restrict__ std_data, + DType* part_gamma_grad, + DType* part_beta_grad) { + extern __shared__ char buf[]; + DType* d_buf = reinterpret_cast(buf); + const int npart = gridDim.y; + const int block_row_num = (nbatch + npart - 1) / npart; + // The rows are divided into `npart` parts. Each threadblock calculates the reduction result + // within the corresponding row ranges. 
+ int row_stride = blockDim.x + 1; + const int c = blockIdx.x * blockDim.x + threadIdx.x; + int r_begin = blockIdx.y * block_row_num; + int r_end = min((blockIdx.y + 1) * block_row_num, nbatch); + DType* buf_gamma_grad = d_buf; + DType* buf_beta_grad = d_buf + blockDim.y * row_stride; + DType local_gamma_grad = 0; + DType local_beta_grad = 0; + + if (c < nchannel) { + for (int r_b = r_begin; r_b < r_end; r_b += blockDim.y) { + int r = r_b + threadIdx.y; + if (r < r_end) { + DType local_mean = mean_data[r]; + DType local_std = std_data[r]; + int read_idx = r * nchannel + c; + local_gamma_grad += (in_data[read_idx] - local_mean) / local_std + * out_grad[read_idx]; + local_beta_grad += out_grad[read_idx]; + } + } + } + buf_gamma_grad[threadIdx.y * row_stride + threadIdx.x] = local_gamma_grad; + buf_beta_grad[threadIdx.y * row_stride + threadIdx.x] = local_beta_grad; + __syncthreads(); + for (int offset = blockDim.y/2; offset > 1; offset >>= 1) { + if (threadIdx.y < offset) { + int idx1 = threadIdx.y * row_stride + threadIdx.x; + int idx2 = (threadIdx.y + offset) * row_stride + threadIdx.x; + buf_gamma_grad[idx1] += buf_gamma_grad[idx2]; + buf_beta_grad[idx1] += buf_beta_grad[idx2]; + } + __syncthreads(); + } + if (threadIdx.y == 0 && c < nchannel) { + part_gamma_grad[blockIdx.y * nchannel + c] = buf_gamma_grad[threadIdx.x] + + buf_gamma_grad[threadIdx.x + row_stride]; + part_beta_grad[blockIdx.y * nchannel + c] = buf_beta_grad[threadIdx.x] + + buf_beta_grad[threadIdx.x + row_stride]; + } +} + +template +__global__ void LayerNormFusedBackwardKernel_GammaBeta(const int nbatch, + const int nchannel, + const int npart, + const DType* __restrict__ part_gamma_grad, + const DType* __restrict__ part_beta_grad, + DType* gamma_grad, + DType* beta_grad) { + const int c = blockIdx.x * blockDim.x + threadIdx.x; + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + if (c < nchannel) { + extern __shared__ char buf[]; + DType* buf_gamma_grad = reinterpret_cast(buf); + DType* buf_beta_grad = reinterpret_cast(buf) + blockDim.x * blockDim.y; + buf_gamma_grad[tid] = 0; + buf_beta_grad[tid] = 0; + for (int r = threadIdx.y; r < npart; r += blockDim.y) { + buf_gamma_grad[tid] += part_gamma_grad[r * nchannel + c]; + buf_beta_grad[tid] += part_beta_grad[r * nchannel + c]; + } + __syncthreads(); + // Begin for inter-warp reduce + if (npart > 1) { + for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset) { + int idx1 = tid; + int idx2 = tid + offset * blockDim.x; + buf_gamma_grad[idx1] += buf_gamma_grad[idx2]; + buf_beta_grad[idx1] += buf_beta_grad[idx2]; + } + __syncthreads(); + } + } + if (threadIdx.y == 0) { + if (gamma_grad) { + if (gamma_addto) { + gamma_grad[c] += buf_gamma_grad[threadIdx.x]; + } else { + gamma_grad[c] = buf_gamma_grad[threadIdx.x]; + } + } + if (beta_grad) { + if (beta_addto) { + beta_grad[c] += buf_beta_grad[threadIdx.x]; + } else { + beta_grad[c] = buf_beta_grad[threadIdx.x]; + } + } + } + } +} + +/* + * + * + */ +template +__global__ void LayerNormFusedBackwardKernel_Data(const int nbatch, + const int nchannel, + const DType* __restrict__ in_data, + const DType* __restrict__ out_grad, + const DType* __restrict__ mean_data, + const DType* __restrict__ std_data, + const DType* __restrict__ gamma, + DType* data_grad) { + int bid = blockIdx.x + blockIdx.y * gridDim.x; + const int nthread = blockDim.x * blockDim.y; + if (bid < nbatch) { + // Shared memory with size blockDim.y * blockDim.x * sizeof(DType) + extern __shared__ char buf[]; + int tid = threadIdx.x 
+ threadIdx.y * blockDim.x; + // 1. Calculate: mean(out_grad * gamma / std, axis=-1) + // mean(out_grad * gamma / std * (x - mean) / std, axis=-1) + DType sum_val0 = 0; // Stores mean(out_grad * gamma / std, axis=-1) + DType sum_val1 = 0; // Stores mean(out_grad * gamma / std * (x - mean) / std, axis=-1) + DType mean = mean_data[bid]; + DType invstd_eps = DType(1) / std_data[bid]; + int l = LOAD_UNROLL * tid; + for (; l + LOAD_UNROLL - 1 < nchannel; l += nthread * LOAD_UNROLL) { +#pragma unroll + for (int i = 0; i < LOAD_UNROLL; ++i) { + DType ele_og = out_grad[bid * nchannel + l + i]; + DType ele_x = in_data[bid * nchannel + l + i]; + DType ele_gamma = gamma[l + i]; + sum_val0 += ele_og * ele_gamma * invstd_eps; + sum_val1 += ele_og * ele_gamma * (ele_x - mean) * invstd_eps * invstd_eps; + } + } + for (; l < nchannel; ++l) { + DType ele_og = out_grad[bid * nchannel + l]; + DType ele_x = in_data[bid * nchannel + l]; + DType ele_gamma = gamma[l]; + sum_val0 += ele_og * ele_gamma * invstd_eps; + sum_val1 += ele_og * ele_gamma * (ele_x - mean) * invstd_eps * invstd_eps; + } + // Intra-warp reduction (all-reduce) + for (int mask = blockDim.x / 2; mask > 0; mask >>= 1) { + sum_val0 += WARP_SHFL_XOR(sum_val0, mask); + sum_val1 += WARP_SHFL_XOR(sum_val1, mask); + } + // Inter-warp reduction (all-reduce) + if (blockDim.y > 1) { + DType* sum_val0_buf = reinterpret_cast(buf); + DType* sum_val1_buf = + reinterpret_cast(buf + blockDim.y / 2 * blockDim.x * sizeof(DType)); + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + sum_val0_buf[idx] = sum_val0; + sum_val1_buf[idx] = sum_val1; + } + __syncthreads(); + if (threadIdx.y < offset) { + const int idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_val0 += sum_val0_buf[idx]; + sum_val1 += sum_val1_buf[idx]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + sum_val0_buf[threadIdx.x] = sum_val0; + sum_val1_buf[threadIdx.x] = sum_val1; + } + __syncthreads(); + sum_val0 = sum_val0_buf[threadIdx.x]; + sum_val1 = sum_val1_buf[threadIdx.x]; + } + sum_val0 /= nchannel; + sum_val1 /= nchannel; + // 2. 
Calculate the gradient as + // out_grad * gamma / std - sum_val0 - (x - mean) / std * sum_val1 + for (int l = tid; l < nchannel; l += nthread) { + DType ele_out_grad = out_grad[bid * nchannel + l]; + DType ele_x = in_data[bid * nchannel + l]; + if (data_addto) { + data_grad[bid * nchannel + l] += + ele_out_grad * gamma[l] * invstd_eps - sum_val0 + - (ele_x - mean) * invstd_eps * sum_val1; + } else { + data_grad[bid * nchannel + l] = + ele_out_grad * gamma[l] * invstd_eps - sum_val0 + - (ele_x - mean) * invstd_eps * sum_val1; + } + } + } +} + +void GetGammaBetaGradKernelParams(const int nbatch, const int nchannel, + dim3* part_grad_block_dim, dim3* part_grad_grid_dim, + dim3* gb_block_dim, dim3* gb_grid_dim, + int* npart) { + *npart = 16; + *part_grad_block_dim = dim3(32, 16); + *part_grad_grid_dim = dim3((nchannel + 32 - 1) / 32, *npart); + *gb_block_dim = dim3(32, *npart); + *gb_grid_dim = dim3((nchannel + 32 - 1) / 32); + CheckLaunchParam(*part_grad_grid_dim, *part_grad_block_dim); + CheckLaunchParam(*gb_grid_dim, *gb_block_dim); +} + +void LayerNormGradGPUContig(const LayerNormParam param, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), 5U); + const TBlob out_grad = inputs[0]; + const TBlob in_data = inputs[1]; + const TBlob gamma = inputs[2]; + const TBlob mean_data = inputs[3]; + const TBlob std_data = inputs[4]; + const TBlob data_grad = outputs[0]; + const TBlob gamma_grad = outputs[1]; + const TBlob beta_grad = outputs[2]; + + // Make sure the inputs are contiguous + CHECK_EQ(out_grad.CheckContiguous(), true); + CHECK_EQ(in_data.CheckContiguous(), true); + CHECK_EQ(gamma.CheckContiguous(), true); + CHECK_EQ(mean_data.CheckContiguous(), true); + CHECK_EQ(std_data.CheckContiguous(), true); + int nbatch = in_data.shape_.ProdShape(0, in_data.ndim() - 1); + int nchannel = in_data.shape_[in_data.ndim() - 1]; + int data_grad_req = req[0]; + int gamma_grad_req = req[1]; + int beta_grad_req = req[2]; + CHECK_NE(data_grad_req, kWriteInplace); + CHECK_NE(gamma_grad_req, kWriteInplace); + CHECK_NE(beta_grad_req, kWriteInplace); + Stream *s = ctx.get_stream(); + cudaStream_t stream = Stream::GetStream(s); + + // Calculate the gradient for gamma/beta + CHECK_EQ(gamma_grad.CheckContiguous(), true); + CHECK_EQ(beta_grad.CheckContiguous(), true); + dim3 part_grad_block_dim, part_grad_grid_dim, gb_block_dim, gb_grid_dim; + int npart; + GetGammaBetaGradKernelParams(nbatch, nchannel, &part_grad_block_dim, &part_grad_grid_dim, + &gb_block_dim, &gb_grid_dim, &npart); + if (gamma_grad_req != kNullOp || beta_grad_req != kNullOp) { + MSHADOW_REAL_TYPE_SWITCH(in_data.type_flag_, DType, { + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(2 * npart * nchannel), s); + DType* part_gamma_grad_ptr = workspace.dptr_; + DType* part_beta_grad_ptr = workspace.dptr_ + npart * nchannel; + const int nshared_K1 = 2 * (part_grad_block_dim.x + 1) + * part_grad_block_dim.y * sizeof(DType); + const int nshared_K2 = 2 * gb_block_dim.x * gb_block_dim.y * sizeof(DType); + DType* gamma_grad_ptr = (gamma_grad_req != kNullOp) ? gamma_grad.dptr() : nullptr; + DType* beta_grad_ptr = (beta_grad_req != kNullOp) ? 
beta_grad.dptr() : nullptr; + LayerNormFusedBackwardKernel_PartGammaBeta + <<>> + (nbatch, nchannel, in_data.dptr(), out_grad.dptr(), + mean_data.dptr(), std_data.dptr(), part_gamma_grad_ptr, part_beta_grad_ptr); + MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_PartGammaBeta); + if (gamma_grad_req == kAddTo && beta_grad_req != kAddTo) { + LayerNormFusedBackwardKernel_GammaBeta + <<>> + (nbatch, nchannel, npart, part_gamma_grad_ptr, part_beta_grad_ptr, + gamma_grad_ptr, beta_grad_ptr); + } else if (gamma_grad_req != kAddTo && beta_grad_req == kAddTo) { + LayerNormFusedBackwardKernel_GammaBeta + <<>> + (nbatch, nchannel, npart, part_gamma_grad_ptr, part_beta_grad_ptr, + gamma_grad_ptr, beta_grad_ptr); + } else if (gamma_grad_req == kAddTo && beta_grad_req == kAddTo) { + LayerNormFusedBackwardKernel_GammaBeta + <<>> + (nbatch, nchannel, npart, part_gamma_grad_ptr, part_beta_grad_ptr, + gamma_grad_ptr, beta_grad_ptr); + } else { + LayerNormFusedBackwardKernel_GammaBeta + <<>> + (nbatch, nchannel, npart, part_gamma_grad_ptr, part_beta_grad_ptr, + gamma_grad_ptr, beta_grad_ptr); + } + }); + MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_GammaBeta); + } + + // Calculate the gradient for data + CHECK_EQ(data_grad.CheckContiguous(), true); + int ngrid_x = (nbatch > kMaxGridDim) ? (nbatch + kBaseGridNum - 1) / kBaseGridNum : nbatch; + int ngrid_y = (nbatch > kMaxGridDim) ? kBaseGridNum : 1; + const dim3 data_grid_dim(ngrid_x, ngrid_y); + int nthread_y; + if (nchannel <= 32) { + nthread_y = 1; + } else if (nchannel <= 128) { + nthread_y = 2; + } else if (nchannel <= 512) { + nthread_y = 4; + } else { + nthread_y = 8; + } + const dim3 data_block_dim(32, nthread_y); + const int LOAD_UNROLL = 4; + if (data_grad_req != kNullOp) { + MSHADOW_REAL_TYPE_SWITCH(in_data.type_flag_, DType, { + int nshared = data_block_dim.y > 1 ? data_block_dim.y * data_block_dim.x * sizeof(DType) : 0; + CheckLaunchParam(data_grid_dim, data_block_dim); + if (data_grad_req == kAddTo) { + LayerNormFusedBackwardKernel_Data + <<>> + (nbatch, nchannel, in_data.dptr(), out_grad.dptr(), mean_data.dptr(), + std_data.dptr(), gamma.dptr(), data_grad.dptr()); + } else { + LayerNormFusedBackwardKernel_Data + <<>> + (nbatch, nchannel, in_data.dptr(), out_grad.dptr(), mean_data.dptr(), + std_data.dptr(), gamma.dptr(), data_grad.dptr()); + } + }); + MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_Data); + } +} + +template<> +void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const LayerNormParam& param = nnvm::get(attrs.parsed); + int axis = param.axis; + if (axis < 0) { + axis += static_cast(inputs[0].ndim()); + } + CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis; + if (axis == inputs[0].ndim() - 1) { + // Try to use the accelerated CUDA kernels + return LayerNormGradGPUContig(param, ctx, inputs, req, outputs); + } + return LayerNormGradComputeGeneral(attrs, ctx, inputs, req, outputs); +} + + NNVM_REGISTER_OP(LayerNorm) .set_attr("FCompute", LayerNormCompute);
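// Illustration (hedged sketch): the two-step reduction pattern used above for the
// gamma/beta gradients. Rows are split into `npart` blocks that each write a partial
// per-column sum (cf. part_gamma_grad / part_beta_grad), and a second pass adds the
// partials (cf. LayerNormFusedBackwardKernel_GammaBeta). Host-side C++ only; names
// and sizes are illustrative, not MXNet API.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

std::vector<double> column_sum_two_step(const std::vector<double>& data,
                                        int nbatch, int nchannel, int npart) {
  std::vector<double> part(npart * nchannel, 0.0);
  int rows_per_part = (nbatch + npart - 1) / npart;
  for (int p = 0; p < npart; ++p) {                  // step 1: partial sums
    int r_begin = p * rows_per_part;
    int r_end = std::min((p + 1) * rows_per_part, nbatch);
    for (int r = r_begin; r < r_end; ++r)
      for (int c = 0; c < nchannel; ++c)
        part[p * nchannel + c] += data[r * nchannel + c];
  }
  std::vector<double> out(nchannel, 0.0);
  for (int p = 0; p < npart; ++p)                    // step 2: combine partials
    for (int c = 0; c < nchannel; ++c)
      out[c] += part[p * nchannel + c];
  return out;
}

int main() {
  const int nbatch = 100, nchannel = 7;
  std::vector<double> data(nbatch * nchannel);
  for (int i = 0; i < nbatch * nchannel; ++i) data[i] = 0.01 * i;
  std::vector<double> ref(nchannel, 0.0);            // direct column sum
  for (int r = 0; r < nbatch; ++r)
    for (int c = 0; c < nchannel; ++c) ref[c] += data[r * nchannel + c];
  std::vector<double> two_step = column_sum_two_step(data, nbatch, nchannel, 16);
  for (int c = 0; c < nchannel; ++c)
    assert(std::abs(two_step[c] - ref[c]) < 1e-9);
  return 0;
}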