From 672af41766c7f311212d16bcbbb86ac4b57a3208 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 11 Jul 2019 10:43:43 -0700 Subject: [PATCH 01/15] Softmax optimization --- src/operator/nn/softmax-inl.h | 176 +++++++++++++++++++++++++++++++++- 1 file changed, 172 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 2c82d839e5ed..428eea023fa9 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -29,6 +29,7 @@ #include #include #include +#include #include "../mxnet_op.h" #include "../operator_common.h" @@ -317,6 +318,155 @@ __global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axi } } +const int softmax_threads_per_block = 512; + +template +__device__ inline T warp_reduce(T value, OP redfun) { + value = redfun(value, __shfl_down_sync(0xffffffff, value, 16)); + value = redfun(value, __shfl_down_sync(0xffffffff, value, 8)); + value = redfun(value, __shfl_down_sync(0xffffffff, value, 4)); + value = redfun(value, __shfl_down_sync(0xffffffff, value, 2)); + value = redfun(value, __shfl_down_sync(0xffffffff, value, 1)); + return value; +} + +template +__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) { + float v = float(value); + v = redfun(v, __shfl_down_sync(0xffffffff, v, 16)); + v = redfun(v, __shfl_down_sync(0xffffffff, v, 8)); + v = redfun(v, __shfl_down_sync(0xffffffff, v, 4)); + v = redfun(v, __shfl_down_sync(0xffffffff, v, 2)); + v = redfun(v, __shfl_down_sync(0xffffffff, v, 1)); + return mshadow::half::half_t(v); +} + +template +__global__ void softmax_compute_kernel2(const DType *in, OType *out, const index_t M, + const double temperature, int rows_per_block, + const index_t total_rows) { + __shared__ AType scratch[softmax_threads_per_block]; + __shared__ LType persistent_storage[20*1024 / sizeof(LType)]; + const int warp_size = 32; + const int threads_per_row = softmax_threads_per_block / rows_per_block; + const int my_local_row = threadIdx.x / threads_per_row; + const int my_row = blockIdx.x * rows_per_block + my_local_row; + if (my_row >= total_rows) return; + const int my_id = threadIdx.x % threads_per_row; + const int entries_per_load = sizeof(LType)/sizeof(DType); + // Due to usage of MSHADOW_TYPE_SWITCH macro we are generating + // kernels where sizeof(LType) may be less than sizeof(DType), + // resulting in entries_per_load being 0. + // This is not a valid combination and is being checked against + // in the launcher code. This switch here is just to silence + // the division by zero warning generated for such invalid cases. + const int row_length = entries_per_load > 0 ? M / entries_per_load : 0; + + const LType * in_aligned = reinterpret_cast(in); + size_t base = my_row * row_length; + + for (index_t i = my_id; i < row_length; i += threads_per_row) { + persistent_storage[my_local_row * row_length + i] = in_aligned[base + i]; + } + DType * row = reinterpret_cast(persistent_storage + my_local_row * row_length); + __syncthreads(); + + DType my_max_value; + red::maximum::SetInitValue(my_max_value); + + for (index_t i = my_id; i < M; i += threads_per_row) { + my_max_value = ::max(my_max_value, negate ? 
-row[i] : row[i]); + } + scratch[threadIdx.x] = my_max_value; + __syncthreads(); + for (int size = threads_per_row / 2; size >= warp_size; size /= 2) { + if (my_id < size) { + scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]); + } + __syncthreads(); + } + if (my_id < warp_size) { + AType my_value = warp_reduce(scratch[threadIdx.x], [](AType x, AType y) { return ::max(x, y); }); + scratch[threadIdx.x] = my_value; + } + __syncthreads(); + DType smax = scratch[threadIdx.x - threadIdx.x % threads_per_row]; + __syncthreads(); + + AType my_sum; + red::sum::SetInitValue(my_sum); + + for (index_t i = my_id; i < M; i += threads_per_row) { + const DType val = negate ? -row[i] : row[i]; + my_sum += static_cast(expf((val - smax) / static_cast(temperature))); + } + scratch[threadIdx.x] = my_sum; + __syncthreads(); + for (int size = threads_per_row / 2; size >= warp_size; size /= 2) { + if (my_id < size) { + scratch[threadIdx.x] += scratch[threadIdx.x + size]; + } + __syncthreads(); + } + if (my_id < warp_size) { + AType my_value = warp_reduce(scratch[threadIdx.x], [](AType x, AType y) { return x + y;}); + scratch[threadIdx.x] = my_value; + } + __syncthreads(); + + AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row]; + __syncthreads(); + + for (index_t i = my_id; i < M; i += threads_per_row) { + const DType val = negate ? -row[i] : row[i]; + row[i] = OP::Map((val - smax)/static_cast(temperature), ssum); + } + __syncthreads(); + + LType * out_aligned = reinterpret_cast(out); + + for (index_t i = my_id; i < row_length; i += threads_per_row) { + out_aligned[base + i] = persistent_storage[my_local_row * row_length + i]; + } +} + +namespace { + +int get_load_type(size_t N) { + if (N % 8 == 0) { + return kFloat64; + } else if (N % 4 == 0) { + return kFloat32; + } else if (N % 2 == 0) { + return kFloat16; + } else { + return kInt8; + } +} + +int get_rows_per_block(size_t N) { + const int warp_size = 32; + // How many read instructions should 1 thread at least do + const int read_instructions = 2; + const int num_threads = (N + read_instructions - 1) / read_instructions; + int num_warps = (num_threads + warp_size - 1) / warp_size; + // num_warps needs to be power of 2 + int used_num_warps = 1; + num_warps = std::min(num_warps, softmax_threads_per_block / warp_size); + int tmp = num_warps; + while (tmp >= 2) { + used_num_warps *= 2; + tmp /= 2; + } + if (used_num_warps < num_warps) { + used_num_warps *= 2; + } + return softmax_threads_per_block / (warp_size * used_num_warps); +} + +} // namespace + template inline void Softmax(Stream *s, DType *in, OType *out, Shape shape, int axis, const double temperature) { @@ -328,10 +478,28 @@ inline void Softmax(Stream *s, DType *in, OType *out, Shape sshape = shape; sshape[axis] = 1; - softmax_compute_kernel - <<::GetStream(s)>>>( - in, out, M, axis, sshape, stride, temperature); - MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel); + const int DSize = sizeof(DType); + // Using 20 kB of shared memory for persistent storage in the optimized case + const int max_opt_M = 20*1024/DSize; + if (stride[axis] == 1 && + M <= max_opt_M && + std::is_same::value) { + int ltype = get_load_type(M * sizeof(DType)); + MSHADOW_TYPE_SWITCH(ltype, LType, { + int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType)); + int nblocks = (N + rows_per_block - 1) / rows_per_block; + CHECK_LE(sizeof(DType), sizeof(LType)); + softmax_compute_kernel2 + <<::GetStream(s)>>>( + in, out, M, temperature, rows_per_block, N); + }); + 
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel2); + } else { + softmax_compute_kernel + <<::GetStream(s)>>>( + in, out, M, axis, sshape, stride, temperature); + MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel); + } } template Date: Mon, 15 Jul 2019 16:02:11 -0700 Subject: [PATCH 02/15] Fix lint --- src/operator/nn/softmax-inl.h | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 428eea023fa9..59b4fc43b1dc 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -332,7 +332,7 @@ __device__ inline T warp_reduce(T value, OP redfun) { template __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) { - float v = float(value); + float v = static_cast(value); v = redfun(v, __shfl_down_sync(0xffffffff, v, 16)); v = redfun(v, __shfl_down_sync(0xffffffff, v, 8)); v = redfun(v, __shfl_down_sync(0xffffffff, v, 4)); @@ -387,7 +387,8 @@ __global__ void softmax_compute_kernel2(const DType *in, OType *out, const index __syncthreads(); } if (my_id < warp_size) { - AType my_value = warp_reduce(scratch[threadIdx.x], [](AType x, AType y) { return ::max(x, y); }); + AType my_value = warp_reduce(scratch[threadIdx.x], + [](AType x, AType y) { return ::max(x, y); }); scratch[threadIdx.x] = my_value; } __syncthreads(); @@ -410,7 +411,8 @@ __global__ void softmax_compute_kernel2(const DType *in, OType *out, const index __syncthreads(); } if (my_id < warp_size) { - AType my_value = warp_reduce(scratch[threadIdx.x], [](AType x, AType y) { return x + y;}); + AType my_value = warp_reduce(scratch[threadIdx.x], + [](AType x, AType y) { return x + y;}); scratch[threadIdx.x] = my_value; } __syncthreads(); From da75ec15b13c78c1ee30939ccaa0881d2f09eac0 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 22 Jul 2019 10:32:04 -0700 Subject: [PATCH 03/15] Unifying softmax with length with regular softmax --- src/operator/nn/softmax-inl.h | 337 ++++++++++++++-------------------- 1 file changed, 134 insertions(+), 203 deletions(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 59b4fc43b1dc..531b3ceb3e12 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -65,8 +65,9 @@ struct log_softmax_fwd { }; -template -inline void Softmax(Stream *s, DType *in, OType *out, +template +inline void Softmax(Stream *s, DType *in, OType *out, IType *length, Shape shape, int axis, const DType temperature) { index_t M = shape[axis]; index_t N = shape.Size()/M; @@ -75,100 +76,89 @@ inline void Softmax(Stream *s, DType *in, OType *out, sshape[axis] = 1; index_t sa = stride[axis]; - #pragma omp parallel for - for (index_t i = 0; i < N; ++i) { - index_t base = unravel_dot(i, sshape, stride); - - DType mmax = negate ? -in[base] : in[base]; - DType val; - for (index_t j = 1; j < M; ++j) { - val = negate ? -in[base + j*sa] : in[base + j*sa]; - if (mmax < val) mmax = val; - } - - AType sum = AType(0); - DType in_val; - // By default temperature is 1.0. - // Adding a branch here to save the CPU 'divide-by-1' computation at runtime - if (temperature == 1.0) { - for (index_t j = 0; j < M; ++j) { - in_val = negate ? -in[base + j*sa] : in[base + j*sa]; - sum += std::exp(in_val - mmax); - } - - for (index_t j = 0; j < M; ++j) { - in_val = negate ? -in[base + j*sa] : in[base + j*sa]; - out[base + j*sa] = OP::Map(in_val - mmax, sum); - } - } else { - for (index_t j = 0; j < M; ++j) { - in_val = negate ? 
-in[base + j*sa] : in[base + j*sa]; - sum += std::exp((in_val - mmax)/temperature); - } + if (length == nullptr) { + #pragma omp parallel for + for (index_t i = 0; i < N; ++i) { + index_t base = unravel_dot(i, sshape, stride); - for (index_t j = 0; j < M; ++j) { - in_val = negate ? -in[base + j*sa] : in[base + j*sa]; - out[base + j*sa] = OP::Map((in_val - mmax)/temperature, sum); + DType mmax = negate ? -in[base] : in[base]; + DType val; + for (index_t j = 1; j < M; ++j) { + val = negate ? -in[base + j*sa] : in[base + j*sa]; + if (mmax < val) mmax = val; } - } - } -} - -template -inline void SoftmaxWithLength(Stream *s, DType *in, OType *out, IType *length, - Shape shape, int axis, const DType temperature) { - index_t M = shape[axis]; - index_t N = shape.Size()/M; - Shape stride = calc_stride(shape); - Shape sshape = shape; - sshape[axis] = 1; - index_t sa = stride[axis]; - #pragma omp parallel for - for (index_t i = 0; i < N; ++i) { - index_t len = static_cast(length[i]); - index_t base = unravel_dot(i, sshape, stride); + AType sum = AType(0); + DType in_val; + // By default temperature is 1.0. + // Adding a branch here to save the CPU 'divide-by-1' computation at runtime + if (temperature == 1.0) { + for (index_t j = 0; j < M; ++j) { + in_val = negate ? -in[base + j*sa] : in[base + j*sa]; + sum += std::exp(in_val - mmax); + } - DType mmax = negate ? -in[base] : in[base]; - DType val; - for (index_t j = 1; j < len; ++j) { - val = negate ? -in[base + j*sa] : in[base + j*sa]; - if (mmax < val) mmax = val; - } - for (index_t j = len; j < M; ++j) { - out[base + j*sa] = OType(0.0f); - } + for (index_t j = 0; j < M; ++j) { + in_val = negate ? -in[base + j*sa] : in[base + j*sa]; + out[base + j*sa] = OP::Map(in_val - mmax, sum); + } + } else { + for (index_t j = 0; j < M; ++j) { + in_val = negate ? -in[base + j*sa] : in[base + j*sa]; + sum += std::exp((in_val - mmax)/temperature); + } - AType sum = AType(0); - DType in_val; - // By default temperature is 1.0. - // Adding a branch here to save the CPU 'divide-by-1' computation at runtime - if (temperature == 1.0) { - for (index_t j = 0; j < len; ++j) { - in_val = negate ? -in[base + j*sa] : in[base + j*sa]; - sum += std::exp(in_val - mmax); + for (index_t j = 0; j < M; ++j) { + in_val = negate ? -in[base + j*sa] : in[base + j*sa]; + out[base + j*sa] = OP::Map((in_val - mmax)/temperature, sum); + } } - - for (index_t j = 0; j < len; ++j) { - in_val = negate ? -in[base + j*sa] : in[base + j*sa]; - out[base + j*sa] = OP::Map(in_val - mmax, sum); + } + } else { + #pragma omp parallel for + for (index_t i = 0; i < N; ++i) { + index_t len = static_cast(length[i]); + index_t base = unravel_dot(i, sshape, stride); + + DType mmax = negate ? -in[base] : in[base]; + DType val; + for (index_t j = 1; j < len; ++j) { + val = negate ? -in[base + j*sa] : in[base + j*sa]; + if (mmax < val) mmax = val; } - } else { - for (index_t j = 0; j < len; ++j) { - in_val = negate ? -in[base + j*sa] : in[base + j*sa]; - sum += std::exp((in_val - mmax)/temperature); + for (index_t j = len; j < M; ++j) { + out[base + j*sa] = OType(0.0f); } - for (index_t j = 0; j < len; ++j) { - in_val = negate ? -in[base + j*sa] : in[base + j*sa]; - out[base + j*sa] = OP::Map((in_val - mmax)/temperature, sum); + AType sum = AType(0); + DType in_val; + // By default temperature is 1.0. + // Adding a branch here to save the CPU 'divide-by-1' computation at runtime + if (temperature == 1.0) { + for (index_t j = 0; j < len; ++j) { + in_val = negate ? 
-in[base + j*sa] : in[base + j*sa]; + sum += std::exp(in_val - mmax); + } + + for (index_t j = 0; j < len; ++j) { + in_val = negate ? -in[base + j*sa] : in[base + j*sa]; + out[base + j*sa] = OP::Map(in_val - mmax, sum); + } + } else { + for (index_t j = 0; j < len; ++j) { + in_val = negate ? -in[base + j*sa] : in[base + j*sa]; + sum += std::exp((in_val - mmax)/temperature); + } + + for (index_t j = 0; j < len; ++j) { + in_val = negate ? -in[base + j*sa] : in[base + j*sa]; + out[base + j*sa] = OP::Map((in_val - mmax)/temperature, sum); + } } } } } - struct softmax_bwd { template MSHADOW_XINLINE static AType Map(DType ograd, DType out, AType sum) { @@ -280,18 +270,19 @@ inline void SoftmaxWithLengthGrad(Stream *s, OType *out, OType *ograd, #ifdef __CUDACC__ template -__global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axis, - Shape sshape, Shape stride, - const double temperature) { + typename DType, typename OType, typename IType> +__global__ void softmax_compute_kernel(DType *in, OType *out, IType *length, + index_t M, int axis, Shape sshape, + Shape stride, const double temperature) { const unsigned x_size = 1 << x_bits; __shared__ AType smem[x_size]; index_t sa = stride[axis]; index_t base = unravel_dot(blockIdx.x, sshape, stride); index_t x = threadIdx.x; + const index_t len = length == nullptr ? M : static_cast(length[blockIdx.x]); red::maximum::SetInitValue(smem[x]); - for (index_t i = x; i < M; i += x_size) { + for (index_t i = x; i < len; i += x_size) { smem[x] = ::max(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]); } __syncthreads(); @@ -302,7 +293,7 @@ __global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axi red::sum::SetInitValue(smem[x]); DType val; - for (index_t i = x; i < M; i += x_size) { + for (index_t i = x; i < len; i += x_size) { val = negate ? -in[base + i*sa]:in[base + i*sa]; smem[x] += static_cast(expf((val - smax) / static_cast(temperature))); } @@ -314,7 +305,8 @@ __global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axi for (index_t i = x; i < M; i += x_size) { val = negate ? -in[base + i*sa] : in[base + i*sa]; - out[base + i*sa] = OP::Map((val - smax)/static_cast(temperature), ssum); + out[base + i*sa] = + (i < len) ? OType(OP::Map((val - smax)/static_cast(temperature), ssum)) : OType(0.0f); } } @@ -342,12 +334,12 @@ __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, } template -__global__ void softmax_compute_kernel2(const DType *in, OType *out, const index_t M, - const double temperature, int rows_per_block, - const index_t total_rows) { + typename DType, typename OType, typename IType> +__global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, IType *length, + const index_t M, const double temperature, + const int rows_per_block, const index_t total_rows) { __shared__ AType scratch[softmax_threads_per_block]; - __shared__ LType persistent_storage[20*1024 / sizeof(LType)]; + __shared__ LType persistent_storage[20 * 1024 / sizeof(LType)]; const int warp_size = 32; const int threads_per_row = softmax_threads_per_block / rows_per_block; const int my_local_row = threadIdx.x / threads_per_row; @@ -355,6 +347,7 @@ __global__ void softmax_compute_kernel2(const DType *in, OType *out, const index if (my_row >= total_rows) return; const int my_id = threadIdx.x % threads_per_row; const int entries_per_load = sizeof(LType)/sizeof(DType); + const index_t len = length == nullptr ? 
M : static_cast(length[my_row]); // Due to usage of MSHADOW_TYPE_SWITCH macro we are generating // kernels where sizeof(LType) may be less than sizeof(DType), // resulting in entries_per_load being 0. @@ -375,7 +368,7 @@ __global__ void softmax_compute_kernel2(const DType *in, OType *out, const index DType my_max_value; red::maximum::SetInitValue(my_max_value); - for (index_t i = my_id; i < M; i += threads_per_row) { + for (index_t i = my_id; i < len; i += threads_per_row) { my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]); } scratch[threadIdx.x] = my_max_value; @@ -398,7 +391,7 @@ __global__ void softmax_compute_kernel2(const DType *in, OType *out, const index AType my_sum; red::sum::SetInitValue(my_sum); - for (index_t i = my_id; i < M; i += threads_per_row) { + for (index_t i = my_id; i < len; i += threads_per_row) { const DType val = negate ? -row[i] : row[i]; my_sum += static_cast(expf((val - smax) / static_cast(temperature))); } @@ -422,7 +415,7 @@ __global__ void softmax_compute_kernel2(const DType *in, OType *out, const index for (index_t i = my_id; i < M; i += threads_per_row) { const DType val = negate ? -row[i] : row[i]; - row[i] = OP::Map((val - smax)/static_cast(temperature), ssum); + row[i] = (i < len) ? DType(OP::Map((val - smax)/static_cast(temperature), ssum)) : DType(0.0f); } __syncthreads(); @@ -469,8 +462,9 @@ int get_rows_per_block(size_t N) { } // namespace -template -inline void Softmax(Stream *s, DType *in, OType *out, +template +inline void Softmax(Stream *s, DType *in, OType *out, IType *length, Shape shape, int axis, const double temperature) { const int x_bits = 7; const int x_size = 1 << x_bits; @@ -480,9 +474,9 @@ inline void Softmax(Stream *s, DType *in, OType *out, Shape sshape = shape; sshape[axis] = 1; - const int DSize = sizeof(DType); + const size_t DSize = sizeof(DType); // Using 20 kB of shared memory for persistent storage in the optimized case - const int max_opt_M = 20*1024/DSize; + const size_t max_opt_M = 20 * 1024 / DSize; if (stride[axis] == 1 && M <= max_opt_M && std::is_same::value) { @@ -491,79 +485,19 @@ inline void Softmax(Stream *s, DType *in, OType *out, int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType)); int nblocks = (N + rows_per_block - 1) / rows_per_block; CHECK_LE(sizeof(DType), sizeof(LType)); - softmax_compute_kernel2 + softmax_stride1_compute_kernel <<::GetStream(s)>>>( - in, out, M, temperature, rows_per_block, N); + in, out, length, M, temperature, rows_per_block, N); }); MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel2); } else { softmax_compute_kernel <<::GetStream(s)>>>( - in, out, M, axis, sshape, stride, temperature); + in, out, length, M, axis, sshape, stride, temperature); MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel); } } -template -__global__ void softmax_with_length_kernel(DType *in, OType *out, IType *length, - index_t M, int axis, Shape sshape, - Shape stride, const double temperature) { - const unsigned x_size = 1 << x_bits; - __shared__ AType smem[x_size]; - index_t sa = stride[axis]; - index_t base = unravel_dot(blockIdx.x, sshape, stride); - index_t x = threadIdx.x; - index_t len = static_cast(length[blockIdx.x]); - - red::maximum::SetInitValue(smem[x]); - for (index_t i = x; i < len; i += x_size) { - smem[x] = ::max(smem[x], negate ? 
-in[base + i*sa] : in[base + i*sa]); - } - __syncthreads(); - cuda::Reduce1D(smem); - __syncthreads(); - DType smax = smem[0]; - __syncthreads(); - - red::sum::SetInitValue(smem[x]); - DType val; - for (index_t i = x; i < len; i += x_size) { - val = negate ? -in[base + i*sa]:in[base + i*sa]; - smem[x] += static_cast(expf((val - smax) / static_cast(temperature))); - } - __syncthreads(); - cuda::Reduce1D(smem); - __syncthreads(); - AType ssum = smem[0]; - __syncthreads(); - - for (index_t i = x; i < M; i += x_size) { - val = negate ? -in[base + i*sa] : in[base + i*sa]; - out[base + i*sa] = - (i < len) ? OType(OP::Map((val - smax)/static_cast(temperature), ssum)) : OType(0.0f); - } -} - -template -inline void SoftmaxWithLength(Stream *s, DType *in, OType *out, IType *length, - Shape shape, int axis, const double temperature) { - const int x_bits = 7; - const int x_size = 1 << x_bits; - index_t M = shape[axis]; - index_t N = shape.Size()/M; - Shape stride = calc_stride(shape); - Shape sshape = shape; - sshape[axis] = 1; - - softmax_with_length_kernel - <<::GetStream(s)>>>( - in, out, length, M, axis, sshape, stride, temperature); - MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel); -} - - template __global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad, @@ -734,8 +668,9 @@ static inline bool SoftmaxOpShape(const nnvm::NodeAttrs& attrs, mxnet::TShape& dshape = in_attrs->at(0); mxnet::TShape tmp_shape((dshape.ndim() == 1) ? 1U : dshape.ndim() - 1, 1); int j = 0; + int axis = param.axis != -1 ? param.axis : dshape.ndim() - 1; for (int i = 0; i < dshape.ndim(); ++i) { - if (i != param.axis) { + if (i != axis) { tmp_shape[j++] = dshape[i]; } } @@ -864,47 +799,43 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs, MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, { MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { - if (!param.use_length.value()) { - if (safe_acc) { - if (shape.ndim() == 2) { - Softmax( - ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<2>(), axis, - static_cast(temperature)); - } else { - Softmax( - ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<3>(), axis, - static_cast(temperature)); - } - } else { - if (shape.ndim() == 2) { - Softmax( - ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<2>(), axis, - static_cast(temperature)); - } else { - Softmax( - ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<3>(), axis, - static_cast(temperature)); + int type = kInt32; + if (param.use_length.value()) { + CHECK(inputs.size() > 1) + << "Mask needs to be provided when using softmax with use_length=True."; + type = inputs[1].type_flag_; + } + MXNET_INT_TYPE_SWITCH(type, IType, { + IType* mask_ptr = nullptr; + if (param.use_length.value()) { + mask_ptr = inputs[1].dptr(); } - } - } else { - MXNET_INT_TYPE_SWITCH(inputs[1].type_flag_, IType, { - if (shape.ndim() == 2) { - SoftmaxWithLength( - ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), inputs[1].dptr(), - shape.get<2>(), axis, static_cast(temperature)); + if (safe_acc) { + if (shape.ndim() == 2) { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), mask_ptr, shape.get<2>(), + axis, static_cast(temperature)); + } else { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), mask_ptr, shape.get<3>(), + axis, static_cast(temperature)); + } } else { - SoftmaxWithLength( - ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), inputs[1].dptr(), - shape.get<3>(), axis, 
static_cast(temperature)); + if (shape.ndim() == 2) { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), mask_ptr, shape.get<2>(), + axis, static_cast(temperature)); + } else { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), mask_ptr, shape.get<3>(), + axis, static_cast(temperature)); + } } - }); - } + }); }); }); } From 40c0d004fc6523356055718c82096d35aca048be Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 23 Jul 2019 09:05:31 -0700 Subject: [PATCH 04/15] Fix lint --- src/operator/nn/softmax-inl.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 531b3ceb3e12..ad02ec9da35e 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -415,7 +415,8 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp for (index_t i = my_id; i < M; i += threads_per_row) { const DType val = negate ? -row[i] : row[i]; - row[i] = (i < len) ? DType(OP::Map((val - smax)/static_cast(temperature), ssum)) : DType(0.0f); + row[i] = (i < len) ? DType(OP::Map((val - smax)/static_cast(temperature), ssum)) : + DType(0.0f); } __syncthreads(); From 9c33c4b4fd40256aba881338e59cf7c12aeaad31 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 23 Jul 2019 14:29:02 -0700 Subject: [PATCH 05/15] Fixes from review --- src/common/cuda_utils.cc | 52 +++++++++++++++++++++++++++++++++++ src/common/cuda_utils.h | 37 +++++++++++++++++++++++-- src/operator/nn/softmax-inl.h | 42 ++++------------------------ 3 files changed, 91 insertions(+), 40 deletions(-) create mode 100644 src/common/cuda_utils.cc diff --git a/src/common/cuda_utils.cc b/src/common/cuda_utils.cc new file mode 100644 index 000000000000..cbf46508cdbf --- /dev/null +++ b/src/common/cuda_utils.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file cuda_utils.cc + * \brief Common CUDA utilities. + */ + +#include +#include +#include "cuda_utils.h" + +#if MXNET_USE_CUDA + +namespace mxnet { +namespace common { +namespace cuda { + +int get_load_type(size_t N) { + using namespace mshadow; + if (N % 8 == 0) { + return kFloat64; + } else if (N % 4 == 0) { + return kFloat32; + } else if (N % 2 == 0) { + return kFloat16; + } else { + return kInt8; + } +} +} // namespace cuda +} // namespace common +} // namespace mxnet + +#endif // MXNET_USE_CUDA diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h index acc8d5fac6df..f16607d8b716 100644 --- a/src/common/cuda_utils.h +++ b/src/common/cuda_utils.h @@ -20,7 +20,7 @@ /*! * Copyright (c) 2015 by Contributors * \file cuda_utils.h - * \brief CUDA debugging utilities. + * \brief Common CUDA utilities. 
*/ #ifndef MXNET_COMMON_CUDA_UTILS_H_ #define MXNET_COMMON_CUDA_UTILS_H_ @@ -326,6 +326,15 @@ class DeviceStore { bool restore_; }; +/*! \brief Get the largest datatype suitable to read + * requested number of bytes. + * + * \input Number of bytes to be read + * \return mshadow representation of type that could + * be used for reading + */ +int get_load_type(size_t N); + } // namespace cuda } // namespace common } // namespace mxnet @@ -550,7 +559,7 @@ static inline __device__ void atomicAdd(double *address, double val) { // Overload atomicAdd for half precision // Taken from: // https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh -#if defined(__CUDA_ARCH__) +#ifdef __CUDACC__ static inline __device__ void atomicAdd(mshadow::half::half_t *address, mshadow::half::half_t val) { unsigned int *address_as_ui = @@ -615,6 +624,28 @@ __device__ inline DType ldg(const DType* address) { return *address; #endif } -#endif + +template +__device__ inline T warp_reduce(T value, OP redfun) { + value = redfun(value, __shfl_down_sync(0xffffffff, value, 16)); + value = redfun(value, __shfl_down_sync(0xffffffff, value, 8)); + value = redfun(value, __shfl_down_sync(0xffffffff, value, 4)); + value = redfun(value, __shfl_down_sync(0xffffffff, value, 2)); + value = redfun(value, __shfl_down_sync(0xffffffff, value, 1)); + return value; +} + +template +__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) { + float v = static_cast(value); + v = redfun(v, __shfl_down_sync(0xffffffff, v, 16)); + v = redfun(v, __shfl_down_sync(0xffffffff, v, 8)); + v = redfun(v, __shfl_down_sync(0xffffffff, v, 4)); + v = redfun(v, __shfl_down_sync(0xffffffff, v, 2)); + v = redfun(v, __shfl_down_sync(0xffffffff, v, 1)); + return mshadow::half::half_t(v); +} + +#endif // __CUDACC__ #endif // MXNET_COMMON_CUDA_UTILS_H_ diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index ad02ec9da35e..0ff42f4d7d63 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -34,6 +34,7 @@ #include "../mxnet_op.h" #include "../operator_common.h" #include "../tensor/broadcast_reduce_op.h" +#include "../../common/cuda_utils.h" namespace mxnet { namespace op { @@ -312,27 +313,6 @@ __global__ void softmax_compute_kernel(DType *in, OType *out, IType *length, const int softmax_threads_per_block = 512; -template -__device__ inline T warp_reduce(T value, OP redfun) { - value = redfun(value, __shfl_down_sync(0xffffffff, value, 16)); - value = redfun(value, __shfl_down_sync(0xffffffff, value, 8)); - value = redfun(value, __shfl_down_sync(0xffffffff, value, 4)); - value = redfun(value, __shfl_down_sync(0xffffffff, value, 2)); - value = redfun(value, __shfl_down_sync(0xffffffff, value, 1)); - return value; -} - -template -__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) { - float v = static_cast(value); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 16)); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 8)); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 4)); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 2)); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 1)); - return mshadow::half::half_t(v); -} - template __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, IType *length, @@ -356,7 +336,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp // the division by zero warning generated for such invalid cases. const int row_length = entries_per_load > 0 ? 
M / entries_per_load : 0; - const LType * in_aligned = reinterpret_cast(in); + const LType* in_aligned = reinterpret_cast(in); size_t base = my_row * row_length; for (index_t i = my_id; i < row_length; i += threads_per_row) { @@ -420,7 +400,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp } __syncthreads(); - LType * out_aligned = reinterpret_cast(out); + LType* out_aligned = reinterpret_cast(out); for (index_t i = my_id; i < row_length; i += threads_per_row) { out_aligned[base + i] = persistent_storage[my_local_row * row_length + i]; @@ -429,18 +409,6 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp namespace { -int get_load_type(size_t N) { - if (N % 8 == 0) { - return kFloat64; - } else if (N % 4 == 0) { - return kFloat32; - } else if (N % 2 == 0) { - return kFloat16; - } else { - return kInt8; - } -} - int get_rows_per_block(size_t N) { const int warp_size = 32; // How many read instructions should 1 thread at least do @@ -479,9 +447,9 @@ inline void Softmax(Stream *s, DType *in, OType *out, IType *length, // Using 20 kB of shared memory for persistent storage in the optimized case const size_t max_opt_M = 20 * 1024 / DSize; if (stride[axis] == 1 && - M <= max_opt_M && + static_cast(M) <= max_opt_M && std::is_same::value) { - int ltype = get_load_type(M * sizeof(DType)); + int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType)); MSHADOW_TYPE_SWITCH(ltype, LType, { int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType)); int nblocks = (N + rows_per_block - 1) / rows_per_block; From 99793abf76abdd290c15f03d9e1758a9939cb7cd Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 25 Jul 2019 10:22:57 -0700 Subject: [PATCH 06/15] Making less templated kernels --- src/common/cuda_utils.cc | 2 +- src/operator/mxnet_op.h | 30 ++++++++++++++++++++++++++++++ src/operator/nn/softmax-inl.h | 4 ++-- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/common/cuda_utils.cc b/src/common/cuda_utils.cc index cbf46508cdbf..2f8835e14d6b 100644 --- a/src/common/cuda_utils.cc +++ b/src/common/cuda_utils.cc @@ -42,7 +42,7 @@ int get_load_type(size_t N) { } else if (N % 2 == 0) { return kFloat16; } else { - return kInt8; + return kUint8; } } } // namespace cuda diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 52788f697f11..ae98952bb7bf 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -414,6 +414,36 @@ inline int get_num_threads(const int N) { LOG(FATAL) << "Unknown type enum " << type; \ } +#define MXNET_LOAD_TYPE_SWITCH(type, DType, ...) \ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "Invalid loading enum type " << type; \ + } + /*! 
* \brief assign the val to out according * to request in Kernel::Launch diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 0ff42f4d7d63..249e5d9ef957 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -450,7 +450,7 @@ inline void Softmax(Stream *s, DType *in, OType *out, IType *length, static_cast(M) <= max_opt_M && std::is_same::value) { int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType)); - MSHADOW_TYPE_SWITCH(ltype, LType, { + MXNET_LOAD_TYPE_SWITCH(ltype, LType, { int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType)); int nblocks = (N + rows_per_block - 1) / rows_per_block; CHECK_LE(sizeof(DType), sizeof(LType)); @@ -458,7 +458,7 @@ inline void Softmax(Stream *s, DType *in, OType *out, IType *length, <<::GetStream(s)>>>( in, out, length, M, temperature, rows_per_block, N); }); - MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel2); + MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_compute_kernel); } else { softmax_compute_kernel <<::GetStream(s)>>>( From fea584dbc52bf2206962e2c07c82c7e920a39eae Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 16 Aug 2019 16:28:21 -0700 Subject: [PATCH 07/15] Unifying softmaxgrad and softmaxwithlengthgrad --- src/operator/nn/softmax-inl.h | 235 ++++++++++++---------------------- 1 file changed, 85 insertions(+), 150 deletions(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 249e5d9ef957..4fc752b9dc4c 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -180,12 +180,11 @@ struct log_softmax_bwd { } }; - template + typename AType, typename DType, typename OType, typename IType, int ndim> inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, - DType *igrad, Shape shape, int axis, - const DType temperature) { + DType *igrad, IType *length, Shape shape, + int axis, const DType temperature) { index_t M = shape[axis]; index_t N = shape.Size()/M; Shape stride = calc_stride(shape); @@ -193,76 +192,65 @@ inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, sshape[axis] = 1; index_t sa = stride[axis]; - #pragma omp parallel for - for (index_t i = 0; i < N; ++i) { - index_t base = unravel_dot(i, sshape, stride); - - AType sum = AType(0); - for (index_t j = 0; j < M; ++j) { - sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]); - } + if (length != nullptr) { + #pragma omp parallel for + for (index_t i = 0; i < N; ++i) { + index_t base = unravel_dot(i, sshape, stride); + index_t len = static_cast(length[i]); - // By default temperature is 1.0. - // Adding a branch here to save the CPU 'divide-by-1' computation at runtime - DType final_result; - if (temperature == 1.0) { - for (index_t j = 0; j < M; ++j) { - final_result = negate ? - -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) : - OP2::Map(ograd[base + j*sa], out[base + j*sa], sum); - KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result); - } - } else { - for (index_t j = 0; j < M; ++j) { - final_result = negate ? 
- -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature : - OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature; - KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result); + AType sum = AType(0); + for (index_t j = 0; j < len; ++j) { + sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]); } - } - } -} -template -inline void SoftmaxWithLengthGrad(Stream *s, OType *out, OType *ograd, - DType *igrad, IType *length, Shape shape, - int axis, const DType temperature) { - index_t M = shape[axis]; - index_t N = shape.Size()/M; - Shape stride = calc_stride(shape); - Shape sshape = shape; - sshape[axis] = 1; - index_t sa = stride[axis]; - - #pragma omp parallel for - for (index_t i = 0; i < N; ++i) { - index_t base = unravel_dot(i, sshape, stride); - index_t len = static_cast(length[i]); - - AType sum = AType(0); - for (index_t j = 0; j < len; ++j) { - sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]); + // By default temperature is 1.0. + // Adding a branch here to save the CPU 'divide-by-1' computation at runtime + DType final_result; + if (temperature == 1.0) { + for (index_t j = 0; j < M; ++j) { + final_result = negate ? + -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) : + OP2::Map(ograd[base + j*sa], out[base + j*sa], sum); + final_result = (j < len) ? final_result : DType(0.0f); + KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result); + } + } else { + for (index_t j = 0; j < M; ++j) { + final_result = negate ? + -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature : + OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature; + final_result = (j < len) ? final_result : DType(0.0f); + KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result); + } + } } + } else { + #pragma omp parallel for + for (index_t i = 0; i < N; ++i) { + index_t base = unravel_dot(i, sshape, stride); - // By default temperature is 1.0. - // Adding a branch here to save the CPU 'divide-by-1' computation at runtime - DType final_result; - if (temperature == 1.0) { + AType sum = AType(0); for (index_t j = 0; j < M; ++j) { - final_result = negate ? - -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) : - OP2::Map(ograd[base + j*sa], out[base + j*sa], sum); - final_result = (j < len) ? final_result : DType(0.0f); - KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result); + sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]); } - } else { - for (index_t j = 0; j < M; ++j) { - final_result = negate ? - -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature : - OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature; - final_result = (j < len) ? final_result : DType(0.0f); - KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result); + + // By default temperature is 1.0. + // Adding a branch here to save the CPU 'divide-by-1' computation at runtime + DType final_result; + if (temperature == 1.0) { + for (index_t j = 0; j < M; ++j) { + final_result = negate ? + -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) : + OP2::Map(ograd[base + j*sa], out[base + j*sa], sum); + KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result); + } + } else { + for (index_t j = 0; j < M; ++j) { + final_result = negate ? 
+ -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature : + OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature; + KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result); + } } } } @@ -467,60 +455,9 @@ inline void Softmax(Stream *s, DType *in, OType *out, IType *length, } } -template -__global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad, - index_t M, int axis, Shape sshape, - Shape stride, const double temperature) { - const unsigned x_size = 1 << x_bits; - __shared__ AType smem[x_size]; - index_t sa = stride[axis]; - index_t base = unravel_dot(blockIdx.x, sshape, stride); - index_t x = threadIdx.x; - - red::sum::SetInitValue(smem[x]); - for (index_t i = x; i < M; i += x_size) { - smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]); - } - __syncthreads(); - cuda::Reduce1D(smem); - __syncthreads(); - AType ssum = smem[0]; - __syncthreads(); - - DType final_result; - for (index_t i = x; i < M; i += x_size) { - final_result = - negate ? - -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) : - OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum); - KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result / static_cast(temperature)); - } -} - - -template -inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, - DType *igrad, Shape shape, int axis, - const double temperature) { - const int x_bits = 7; - const int x_size = 1 << x_bits; - index_t M = shape[axis]; - index_t N = shape.Size()/M; - Shape stride = calc_stride(shape); - Shape sshape = shape; - sshape[axis] = 1; - - softmax_gradient_kernel - <<::GetStream(s)>>>( - out, ograd, igrad, M, axis, sshape, stride, temperature); - MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel); -} - template -__global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType *igrad, +__global__ void softmax_grad_kernel(OType *out, OType *ograd, DType *igrad, IType *length, index_t M, int axis, Shape sshape, Shape stride, const double temperature) { @@ -529,7 +466,7 @@ __global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType index_t sa = stride[axis]; index_t base = unravel_dot(blockIdx.x, sshape, stride); index_t x = threadIdx.x; - index_t len = static_cast(length[blockIdx.x]); + index_t len = length != nullptr ? 
static_cast(length[blockIdx.x]) : M; red::sum::SetInitValue(smem[x]); for (index_t i = x; i < len; i += x_size) { @@ -552,12 +489,11 @@ __global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType } } - template -inline void SoftmaxWithLengthGrad(Stream *s, OType *out, OType *ograd, - DType *igrad, IType *length, Shape shape, int axis, - const double temperature) { +inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, + DType *igrad, IType *length, Shape shape, int axis, + const double temperature) { const int x_bits = 7; const int x_size = 1 << x_bits; index_t M = shape[axis]; @@ -566,10 +502,10 @@ inline void SoftmaxWithLengthGrad(Stream *s, OType *out, OType *ograd, Shape sshape = shape; sshape[axis] = 1; - softmax_with_length_grad_kernel + softmax_grad_kernel <<::GetStream(s)>>>( out, ograd, igrad, length, M, axis, sshape, stride, temperature); - MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_with_length_grad_kernel); + MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_grad_kernel); } #endif @@ -817,7 +753,16 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { using namespace mxnet_op; + if (softmax_use_length(attrs)) { + MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, { + if (req[1] != kNullOp) { + mxnet_op::Kernel::Launch( + ctx.get_stream(), outputs[1].Size(), outputs[1].dptr()); + } + }); + } if (req[0] == kNullOp) return; + const int itype = softmax_use_length(attrs) ? inputs[2].type_flag_ : kInt32; const SoftmaxParam& param = nnvm::get(attrs.parsed); int axis = CheckAxis(param.axis, inputs[0].ndim()); const double temperature = param.temperature.has_value() ? @@ -831,51 +776,41 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, { MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - if (!softmax_use_length(attrs)) { + MXNET_INT_TYPE_SWITCH(itype, IType, { + IType * length_ptr = nullptr; + if (softmax_use_length(attrs)) { + length_ptr = inputs[2].dptr(); + } if (safe_acc) { if (shape.ndim() == 2) { SoftmaxGrad( ctx.get_stream(), inputs[out_idx].dptr(), inputs[0].dptr(), outputs[0].dptr(), - shape.get<2>(), axis, static_cast(temperature)); + length_ptr, shape.get<2>(), axis, + static_cast(temperature)); } else { SoftmaxGrad( ctx.get_stream(), inputs[out_idx].dptr(), inputs[0].dptr(), outputs[0].dptr(), - shape.get<3>(), axis, static_cast(temperature)); + length_ptr, shape.get<3>(), axis, + static_cast(temperature)); } } else { if (shape.ndim() == 2) { SoftmaxGrad( ctx.get_stream(), inputs[out_idx].dptr(), inputs[0].dptr(), outputs[0].dptr(), - shape.get<2>(), axis, static_cast(temperature)); + length_ptr, shape.get<2>(), axis, + static_cast(temperature)); } else { SoftmaxGrad( ctx.get_stream(), inputs[out_idx].dptr(), inputs[0].dptr(), outputs[0].dptr(), - shape.get<3>(), axis, static_cast(temperature)); + length_ptr, shape.get<3>(), axis, + static_cast(temperature)); } } - } else { - MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, { - if (req[1] != kNullOp) { - mxnet_op::Kernel::Launch( - ctx.get_stream(), outputs[1].Size(), outputs[1].dptr()); - } - if (shape.ndim() == 2) { - SoftmaxWithLengthGrad( - ctx.get_stream(), inputs[out_idx].dptr(), - inputs[0].dptr(), outputs[0].dptr(), - inputs[2].dptr(), shape.get<2>(), axis, static_cast(temperature)); - } else { - SoftmaxWithLengthGrad( - ctx.get_stream(), inputs[out_idx].dptr(), - inputs[0].dptr(), outputs[0].dptr(), - inputs[2].dptr(), 
shape.get<3>(), axis, static_cast(temperature)); - } - }); - } + }); }); }); }); From 77d52fdfa2c80137fbb72bf56cfb7e2a36fabdd7 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 19 Aug 2019 13:56:14 -0700 Subject: [PATCH 08/15] Better gradient of softmax --- src/operator/nn/softmax-inl.h | 116 ++++++++++++++++++++++++++++++++-- 1 file changed, 109 insertions(+), 7 deletions(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 4fc752b9dc4c..0d462e8a07f3 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -455,12 +455,94 @@ inline void Softmax(Stream *s, DType *in, OType *out, IType *length, } } +template +__global__ void softmax_stride1_grad_kernel(const OType *out, const OType *ograd, + DType *igrad, const IType *length, + const index_t M, + const double temperature, + const int rows_per_block, + const index_t total_rows) { + __shared__ AType scratch[softmax_threads_per_block]; + __shared__ LType persistent_storage[20 * 1024 / sizeof(LType)]; + const int warp_size = 32; + const int threads_per_row = softmax_threads_per_block / rows_per_block; + const int my_local_row = threadIdx.x / threads_per_row; + const int my_row = blockIdx.x * rows_per_block + my_local_row; + if (my_row >= total_rows) return; + const int my_id = threadIdx.x % threads_per_row; + const int entries_per_load = sizeof(LType)/sizeof(DType); + const index_t len = length == nullptr ? M : static_cast(length[my_row]); + // Due to usage of MSHADOW_TYPE_SWITCH macro we are generating + // kernels where sizeof(LType) may be less than sizeof(DType), + // resulting in entries_per_load being 0. + // This is not a valid combination and is being checked against + // in the launcher code. This switch here is just to silence + // the division by zero warning generated for such invalid cases. + const int row_length = entries_per_load > 0 ? M / entries_per_load : 0; + + const LType* out_aligned = reinterpret_cast(out); + const LType* ograd_aligned = reinterpret_cast(ograd); + size_t base = my_row * row_length; + + for (index_t i = my_id; i < row_length; i += threads_per_row) { + persistent_storage[my_local_row * row_length * 2 + i] = out_aligned[base + i]; + persistent_storage[my_local_row * row_length * 2 + row_length + i] = ograd_aligned[base + i]; + } + DType * row = reinterpret_cast(persistent_storage + my_local_row * row_length * 2); + __syncthreads(); + + AType my_sum_value; + red::sum::SetInitValue(my_sum_value); + + for (index_t i = my_id; i < len; i += threads_per_row) { + my_sum_value += OP1::Map(row[i + M], row[i]); + } + scratch[threadIdx.x] = my_sum_value; + __syncthreads(); + for (int size = threads_per_row / 2; size >= warp_size; size /= 2) { + if (my_id < size) { + scratch[threadIdx.x] = scratch[threadIdx.x] + scratch[threadIdx.x + size]; + } + __syncthreads(); + } + if (my_id < warp_size) { + AType my_value = warp_reduce(scratch[threadIdx.x], + [](AType x, AType y) { return x + y; }); + scratch[threadIdx.x] = my_value; + } + __syncthreads(); + AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row]; + __syncthreads(); + + for (index_t i = my_id; i < M; i += threads_per_row) { + const DType val = + negate ? + -OP2::Map(row[i + M], row[i], ssum) : + OP2::Map(row[i + M], row[i], ssum); + row[i] = (i < len) ? 
DType(val / static_cast(temperature)) : + DType(0.0f); + if (Req == kAddTo) { + row[i] += igrad[my_row * M + i]; + } + } + if (Req == kAddTo) { + __syncthreads(); + } + + LType* igrad_aligned = reinterpret_cast(igrad); + + for (index_t i = my_id; i < row_length; i += threads_per_row) { + igrad_aligned[base + i] = persistent_storage[my_local_row * row_length * 2 + i]; + } +} + template __global__ void softmax_grad_kernel(OType *out, OType *ograd, DType *igrad, - IType *length, index_t M, int axis, - Shape sshape, Shape stride, - const double temperature) { + const IType *length, index_t M, int axis, + Shape sshape, Shape stride, + const double temperature) { const unsigned x_size = 1 << x_bits; __shared__ AType smem[x_size]; index_t sa = stride[axis]; @@ -502,10 +584,30 @@ inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, Shape sshape = shape; sshape[axis] = 1; - softmax_grad_kernel - <<::GetStream(s)>>>( - out, ograd, igrad, length, M, axis, sshape, stride, temperature); - MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_grad_kernel); + const size_t DSize = sizeof(DType); + // Using 20 kB of shared memory for persistent storage in the optimized case + // Need to store both out and ograd, so M can be only half compared to + // forward pass. + const size_t max_opt_M = 20 * 1024 / DSize / 2; + if (stride[axis] == 1 && + static_cast(M) <= max_opt_M && + std::is_same::value) { + int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType)); + MXNET_LOAD_TYPE_SWITCH(ltype, LType, { + int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType)); + int nblocks = (N + rows_per_block - 1) / rows_per_block; + CHECK_LE(sizeof(DType), sizeof(LType)); + softmax_stride1_grad_kernel + <<::GetStream(s)>>>( + out, ograd, igrad, length, M, temperature, rows_per_block, N); + }); + MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_grad_kernel); + } else { + softmax_grad_kernel + <<::GetStream(s)>>>( + out, ograd, igrad, length, M, axis, sshape, stride, temperature); + MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_grad_kernel); + } } #endif From 76383408786b863041600cbf68a982dbad9a1be6 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 19 Aug 2019 14:49:50 -0700 Subject: [PATCH 09/15] Dividing softmax.cc into multiple files --- src/operator/nn/log_softmax.cc | 77 ++++++++++++++++++++++++++ src/operator/nn/log_softmax.cu | 39 ++++++++++++++ src/operator/nn/softmax.cc | 98 ---------------------------------- src/operator/nn/softmax.cu | 18 +------ src/operator/nn/softmin.cc | 89 ++++++++++++++++++++++++++++++ src/operator/nn/softmin.cu | 39 ++++++++++++++ 6 files changed, 246 insertions(+), 114 deletions(-) create mode 100644 src/operator/nn/log_softmax.cc create mode 100644 src/operator/nn/log_softmax.cu create mode 100644 src/operator/nn/softmin.cc create mode 100644 src/operator/nn/softmin.cu diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc new file mode 100644 index 000000000000..a2fd1198e59c --- /dev/null +++ b/src/operator/nn/log_softmax.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file log_softmax.cc + * \brief CPU Implementation of log_softmax + */ +#include "./softmax-inl.h" +#include "../tensor/elemwise_unary_op.h" +#include "../tensor/elemwise_binary_op.h" +#include "../operator_common.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(log_softmax) +.describe(R"code(Computes the log softmax of the input. +This is equivalent to computing softmax followed by log. + +Examples:: + + >>> x = mx.nd.array([1, 2, .1]) + >>> mx.nd.log_softmax(x).asnumpy() + array([-1.41702998, -0.41702995, -2.31702995], dtype=float32) + + >>> x = mx.nd.array( [[1, 2, .1],[.1, 2, 1]] ) + >>> mx.nd.log_softmax(x, axis=0).asnumpy() + array([[-0.34115392, -0.69314718, -1.24115396], + [-1.24115396, -0.69314718, -0.34115392]], dtype=float32) + + +)code") +.set_attr_parser(ParamParser) +.set_attr("FCompute", SoftmaxCompute) +.set_attr("FGradient", SoftmaxFGradient{"_backward_log_softmax"}) +.set_attr("FInferType", SoftmaxOpType) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.add_argument("data", "NDArray-or-Symbol", "The input array.") +.add_arguments(SoftmaxParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_log_softmax) +.set_num_inputs(SoftmaxGradOpNumInputs) +.set_num_outputs(1) +.set_attr("FListInputNames", SoftmaxGradOpInputNames) +.set_attr("FInferShape", SoftmaxGradOpShape) +.set_attr("FInferType", SoftmaxGradOpType) +.set_attr("FInplaceOption", SoftmaxGradOpInplaceOption) +.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments") +.set_attr_parser(ParamParser) +.set_attr("FCompute", SoftmaxGradCompute); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/nn/log_softmax.cu b/src/operator/nn/log_softmax.cu new file mode 100644 index 000000000000..8bff2777f7b9 --- /dev/null +++ b/src/operator/nn/log_softmax.cu @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2017 by Contributors + * \file log_softmax.cu + * \brief GPU Implementation of log_softmax + */ +#include "./softmax-inl.h" +#include "../tensor/elemwise_unary_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(log_softmax) +.set_attr("FCompute", SoftmaxCompute); + +NNVM_REGISTER_OP(_backward_log_softmax) +.set_attr("FCompute", SoftmaxGradCompute); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc index 5a581e4ea5ef..2abdf45ed0cf 100644 --- a/src/operator/nn/softmax.cc +++ b/src/operator/nn/softmax.cc @@ -149,103 +149,5 @@ NNVM_REGISTER_OP(_backward_softmax) .set_attr("FCompute", SoftmaxGradCompute); -NNVM_REGISTER_OP(softmin) -.describe(R"code(Applies the softmin function. - -The resulting array contains elements in the range (0,1) and the elements along the given axis sum -up to 1. - -.. math:: - softmin(\mathbf{z/t})_j = \frac{e^{-z_j/t}}{\sum_{k=1}^K e^{-z_k/t}} - -for :math:`j = 1, ..., K` - -t is the temperature parameter in softmax function. By default, t equals 1.0 - -Example:: - - x = [[ 1. 2. 3.] - [ 3. 2. 1.]] - - softmin(x,axis=0) = [[ 0.88079703, 0.5, 0.11920292], - [ 0.11920292, 0.5, 0.88079703]] - - softmin(x,axis=1) = [[ 0.66524094, 0.24472848, 0.09003057], - [ 0.09003057, 0.24472848, 0.66524094]] - -)code" ADD_FILELINE) -.set_attr_parser(ParamParser) -.set_attr("FListOutputNames", - [](const NodeAttrs& attrs) { - return std::vector{"output"}; -}) -.set_attr("FCompute", SoftmaxCompute) -.set_attr("FGradient", SoftmaxFGradient{"_backward_softmin"}) -.set_attr("FInferType", SoftmaxOpType) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr("FInferShape", ElemwiseShape<1, 1>) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs){ - return std::vector >{{0, 0}}; - }) -.add_argument("data", "NDArray-or-Symbol", "The input array.") -.add_arguments(SoftmaxParam::__FIELDS__()); - -NNVM_REGISTER_OP(_backward_softmin) -.set_num_inputs(SoftmaxGradOpNumInputs) -.set_num_outputs(1) -.set_attr("FListInputNames", SoftmaxGradOpInputNames) -.set_attr("FInferShape", SoftmaxGradOpShape) -.set_attr("FInferType", SoftmaxGradOpType) -.set_attr("FInplaceOption", SoftmaxGradOpInplaceOption) -.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments") -.set_attr_parser(ParamParser) -.set_attr("FCompute", SoftmaxGradCompute); - -NNVM_REGISTER_OP(log_softmax) -.describe(R"code(Computes the log softmax of the input. -This is equivalent to computing softmax followed by log. 
- -Examples:: - - >>> x = mx.nd.array([1, 2, .1]) - >>> mx.nd.log_softmax(x).asnumpy() - array([-1.41702998, -0.41702995, -2.31702995], dtype=float32) - - >>> x = mx.nd.array( [[1, 2, .1],[.1, 2, 1]] ) - >>> mx.nd.log_softmax(x, axis=0).asnumpy() - array([[-0.34115392, -0.69314718, -1.24115396], - [-1.24115396, -0.69314718, -0.34115392]], dtype=float32) - - -)code") -.set_attr_parser(ParamParser) -.set_attr("FCompute", SoftmaxCompute) -.set_attr("FGradient", SoftmaxFGradient{"_backward_log_softmax"}) -.set_attr("FInferType", SoftmaxOpType) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr("FInferShape", ElemwiseShape<1, 1>) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs){ - return std::vector >{{0, 0}}; - }) -.add_argument("data", "NDArray-or-Symbol", "The input array.") -.add_arguments(SoftmaxParam::__FIELDS__()); - -NNVM_REGISTER_OP(_backward_log_softmax) -.set_num_inputs(SoftmaxGradOpNumInputs) -.set_num_outputs(1) -.set_attr("FListInputNames", SoftmaxGradOpInputNames) -.set_attr("FInferShape", SoftmaxGradOpShape) -.set_attr("FInferType", SoftmaxGradOpType) -.set_attr("FInplaceOption", SoftmaxGradOpInplaceOption) -.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments") -.set_attr_parser(ParamParser) -.set_attr("FCompute", SoftmaxGradCompute); - } // namespace op } // namespace mxnet diff --git a/src/operator/nn/softmax.cu b/src/operator/nn/softmax.cu index 254e726d5e26..d5762cf0217f 100644 --- a/src/operator/nn/softmax.cu +++ b/src/operator/nn/softmax.cu @@ -19,8 +19,8 @@ /*! * Copyright (c) 2017 by Contributors - * \file softmax.cc - * \brief CPU Implementation of softmax + * \file softmax.cu + * \brief GPU Implementation of softmax */ #include "./softmax-inl.h" #include "../tensor/elemwise_unary_op.h" @@ -35,19 +35,5 @@ NNVM_REGISTER_OP(_backward_softmax) .set_attr("FCompute", SoftmaxGradCompute); -NNVM_REGISTER_OP(softmin) -.set_attr("FCompute", SoftmaxCompute); - -NNVM_REGISTER_OP(_backward_softmin) -.set_attr("FCompute", SoftmaxGradCompute); - -NNVM_REGISTER_OP(log_softmax) -.set_attr("FCompute", SoftmaxCompute); - -NNVM_REGISTER_OP(_backward_log_softmax) -.set_attr("FCompute", SoftmaxGradCompute); - } // namespace op } // namespace mxnet diff --git a/src/operator/nn/softmin.cc b/src/operator/nn/softmin.cc new file mode 100644 index 000000000000..0522c8c9b120 --- /dev/null +++ b/src/operator/nn/softmin.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2017 by Contributors + * \file softmax.cc + * \brief CPU Implementation of softmin + */ +#include "./softmax-inl.h" +#include "../tensor/elemwise_unary_op.h" +#include "../tensor/elemwise_binary_op.h" +#include "../operator_common.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(softmin) +.describe(R"code(Applies the softmin function. + +The resulting array contains elements in the range (0,1) and the elements along the given axis sum +up to 1. + +.. math:: + softmin(\mathbf{z/t})_j = \frac{e^{-z_j/t}}{\sum_{k=1}^K e^{-z_k/t}} + +for :math:`j = 1, ..., K` + +t is the temperature parameter in softmax function. By default, t equals 1.0 + +Example:: + + x = [[ 1. 2. 3.] + [ 3. 2. 1.]] + + softmin(x,axis=0) = [[ 0.88079703, 0.5, 0.11920292], + [ 0.11920292, 0.5, 0.88079703]] + + softmin(x,axis=1) = [[ 0.66524094, 0.24472848, 0.09003057], + [ 0.09003057, 0.24472848, 0.66524094]] + +)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) +.set_attr("FCompute", SoftmaxCompute) +.set_attr("FGradient", SoftmaxFGradient{"_backward_softmin"}) +.set_attr("FInferType", SoftmaxOpType) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.add_argument("data", "NDArray-or-Symbol", "The input array.") +.add_arguments(SoftmaxParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_softmin) +.set_num_inputs(SoftmaxGradOpNumInputs) +.set_num_outputs(1) +.set_attr("FListInputNames", SoftmaxGradOpInputNames) +.set_attr("FInferShape", SoftmaxGradOpShape) +.set_attr("FInferType", SoftmaxGradOpType) +.set_attr("FInplaceOption", SoftmaxGradOpInplaceOption) +.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments") +.set_attr_parser(ParamParser) +.set_attr("FCompute", SoftmaxGradCompute); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/nn/softmin.cu b/src/operator/nn/softmin.cu new file mode 100644 index 000000000000..d00d0bdad231 --- /dev/null +++ b/src/operator/nn/softmin.cu @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
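As the description above states, softmin is softmax applied to the negated, temperature-scaled input. A standalone C++ sketch of that relationship (illustration only; softmin_ref is not a name from this patch), reproducing the axis=1 row [1, 2, 3] from the example::

    // Standalone illustration only -- not MXNet code.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // softmin(z, t) = exp(-z_j/t) / sum_k exp(-z_k/t), i.e. softmax of -z,
    // evaluated with the usual max-subtraction for numerical stability.
    std::vector<float> softmin_ref(std::vector<float> x, float t = 1.f) {
      for (float& v : x) v = -v;  // negate, then apply softmax
      const float xmax = *std::max_element(x.begin(), x.end());
      float sum = 0.f;
      for (float v : x) sum += std::exp((v - xmax) / t);
      for (float& v : x) v = std::exp((v - xmax) / t) / sum;
      return x;
    }

    int main() {
      // Row [1, 2, 3] from the example: expected ~[0.665, 0.245, 0.090].
      for (float v : softmin_ref({1.f, 2.f, 3.f})) std::printf("%f ", v);
      std::printf("\n");
      return 0;
    }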
+ * Copyright (c) 2017 by Contributors + * \file softmin.cu + * \brief GPU Implementation of softmin + */ +#include "./softmax-inl.h" +#include "../tensor/elemwise_unary_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(softmin) +.set_attr("FCompute", SoftmaxCompute); + +NNVM_REGISTER_OP(_backward_softmin) +.set_attr("FCompute", SoftmaxGradCompute); + +} // namespace op +} // namespace mxnet From b7670c2ca275b7c327a9682f546aed782023f813 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 19 Aug 2019 15:43:44 -0700 Subject: [PATCH 10/15] Fix --- src/operator/nn/softmax-inl.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 0d462e8a07f3..fbefcc1b1057 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -526,9 +526,7 @@ __global__ void softmax_stride1_grad_kernel(const OType *out, const OType *ograd row[i] += igrad[my_row * M + i]; } } - if (Req == kAddTo) { - __syncthreads(); - } + __syncthreads(); LType* igrad_aligned = reinterpret_cast(igrad); From 9a406d11faba40b24f4336d235065eb1fb55130a Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 20 Aug 2019 08:32:15 -0700 Subject: [PATCH 11/15] Trigger CI From 46175c01f49aae1f754191037b0a5ffc9da53d80 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 21 Aug 2019 09:44:43 -0700 Subject: [PATCH 12/15] Moving get_rows_per_block to common place --- src/common/cuda_utils.cc | 35 +++++++++++++++++++++++++++++++++++ src/common/cuda_utils.h | 15 ++++++++++++++- src/operator/nn/softmax-inl.h | 30 ++++-------------------------- 3 files changed, 53 insertions(+), 27 deletions(-) diff --git a/src/common/cuda_utils.cc b/src/common/cuda_utils.cc index 2f8835e14d6b..2a9b3ec4e907 100644 --- a/src/common/cuda_utils.cc +++ b/src/common/cuda_utils.cc @@ -33,6 +33,25 @@ namespace mxnet { namespace common { namespace cuda { +namespace { + bool IsPower2(size_t N) { + return ((N & (N - 1)) == 0) && N != 0; + } + + size_t RoundToPower2(size_t N) { + size_t ret = 1; + size_t copyN = N; + while (N >= 2) { + ret *= 2; + N /= 2; + } + if (ret < copyN) { + ret *= 2; + } + return ret; + } +} // namespace + int get_load_type(size_t N) { using namespace mshadow; if (N % 8 == 0) { @@ -45,6 +64,22 @@ int get_load_type(size_t N) { return kUint8; } } + +int get_rows_per_block(size_t row_size, int num_threads_per_block) { + const int warp_size = 32; + CHECK(IsPower2(num_threads_per_block)) + << "Number of threads in a block must be power of 2 to use get_rows_per_block function"; + // How many read instructions should 1 thread at least do + const int read_instructions = 2; + const int desired_num_threads_per_row = (row_size + read_instructions - 1) / read_instructions; + int desired_num_warps_per_row = (desired_num_threads_per_row + warp_size - 1) / warp_size; + int actual_num_warps_per_row = std::min(desired_num_warps_per_row, + num_threads_per_block / warp_size); + // actual number of warps needs to be power of 2 + actual_num_warps_per_row = RoundToPower2(desired_num_warps_per_row); + return num_threads_per_block / (warp_size * actual_num_warps_per_row); +} + } // namespace cuda } // namespace common } // namespace mxnet diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h index fb5454cb5d12..dd98a0e6c966 100644 --- a/src/common/cuda_utils.h +++ b/src/common/cuda_utils.h @@ -327,7 +327,8 @@ class DeviceStore { bool restore_; }; -/*! \brief Get the largest datatype suitable to read +/*! 
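The get_rows_per_block helper added above packs several short softmax rows into one block: it estimates how many warps a single row needs assuming each thread issues at least two loads, limits that to the warps available in the block, rounds up to a power of two, and divides the block's warps by the result. A standalone sketch of that heuristic (illustration only; it assumes the power-of-two rounding is meant to apply to the clamped warp count, and rows_per_block_sketch is not a name from this patch)::

    // Standalone sketch of the rows-per-block heuristic -- illustration only.
    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    int rows_per_block_sketch(size_t row_size, int threads_per_block) {
      const int warp_size = 32;
      const int loads_per_thread = 2;  // each thread should issue at least two loads
      const size_t want_threads = (row_size + loads_per_thread - 1) / loads_per_thread;
      int warps = static_cast<int>((want_threads + warp_size - 1) / warp_size);
      warps = std::min(warps, threads_per_block / warp_size);  // clamp to the block's warp budget
      int pow2 = 1;
      while (pow2 < warps) pow2 *= 2;  // round up to a power of two
      return threads_per_block / (warp_size * pow2);
    }

    int main() {
      // With 512-thread blocks: short rows share a block, long rows get a whole block.
      for (int row_size : {8, 64, 256, 1024}) {
        std::printf("row_size=%d -> rows_per_block=%d\n",
                    row_size, rows_per_block_sketch(row_size, 512));
      }
      return 0;
    }

With 512-thread blocks this yields 16 rows per block for very short rows and a single row per block once a row needs all 16 warps.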
+ * \brief Get the largest datatype suitable to read * requested number of bytes. * * \input Number of bytes to be read @@ -336,6 +337,18 @@ class DeviceStore { */ int get_load_type(size_t N); +/*! + * \brief Determine how many rows in a 2D matrix should a block + * of threads handle based on the row size and the number + * of threads in a block. + * \param row_size Size of the row expressed in the number of reads required to fully + * load it. For example, if the row has N elements, but each thread + * reads 2 elements with a single read, row_size should be N / 2. + * \param num_threads_per_block Number of threads in a block. + * \return the number of rows that should be handled by a single block. + */ +int get_rows_per_block(size_t row_size, int num_threads_per_block); + } // namespace cuda } // namespace common } // namespace mxnet diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index fbefcc1b1057..93522645b642 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -395,30 +395,6 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp } } -namespace { - -int get_rows_per_block(size_t N) { - const int warp_size = 32; - // How many read instructions should 1 thread at least do - const int read_instructions = 2; - const int num_threads = (N + read_instructions - 1) / read_instructions; - int num_warps = (num_threads + warp_size - 1) / warp_size; - // num_warps needs to be power of 2 - int used_num_warps = 1; - num_warps = std::min(num_warps, softmax_threads_per_block / warp_size); - int tmp = num_warps; - while (tmp >= 2) { - used_num_warps *= 2; - tmp /= 2; - } - if (used_num_warps < num_warps) { - used_num_warps *= 2; - } - return softmax_threads_per_block / (warp_size * used_num_warps); -} - -} // namespace - template inline void Softmax(Stream *s, DType *in, OType *out, IType *length, @@ -439,7 +415,8 @@ inline void Softmax(Stream *s, DType *in, OType *out, IType *length, std::is_same::value) { int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType)); MXNET_LOAD_TYPE_SWITCH(ltype, LType, { - int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType)); + int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType), + softmax_threads_per_block); int nblocks = (N + rows_per_block - 1) / rows_per_block; CHECK_LE(sizeof(DType), sizeof(LType)); softmax_stride1_compute_kernel @@ -592,7 +569,8 @@ inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, std::is_same::value) { int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType)); MXNET_LOAD_TYPE_SWITCH(ltype, LType, { - int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType)); + int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType), + softmax_threads_per_block); int nblocks = (N + rows_per_block - 1) / rows_per_block; CHECK_LE(sizeof(DType), sizeof(LType)); softmax_stride1_grad_kernel From 9340a33723fe40e66585e766e41b635e98b26101 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 21 Aug 2019 10:28:22 -0700 Subject: [PATCH 13/15] Fix --- src/operator/nn/softmax-inl.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 93522645b642..73e0ccce4328 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -415,8 +415,9 @@ inline void Softmax(Stream *s, DType *in, OType *out, IType *length, std::is_same::value) { int ltype = 
mxnet::common::cuda::get_load_type(M * sizeof(DType)); MXNET_LOAD_TYPE_SWITCH(ltype, LType, { - int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType), - softmax_threads_per_block); + int rows_per_block = mxnet::common::cuda::get_rows_per_block(M * + sizeof(DType) / sizeof(LType), + softmax_threads_per_block); int nblocks = (N + rows_per_block - 1) / rows_per_block; CHECK_LE(sizeof(DType), sizeof(LType)); softmax_stride1_compute_kernel @@ -569,8 +570,9 @@ inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, std::is_same::value) { int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType)); MXNET_LOAD_TYPE_SWITCH(ltype, LType, { - int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType), - softmax_threads_per_block); + int rows_per_block = mxnet::common::cuda::get_rows_per_block(M * + sizeof(DType) / sizeof(LType), + softmax_threads_per_block); int nblocks = (N + rows_per_block - 1) / rows_per_block; CHECK_LE(sizeof(DType), sizeof(LType)); softmax_stride1_grad_kernel From c020cfd6323df21ea2f4da30409aff7dc40ae137 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 21 Aug 2019 10:48:15 -0700 Subject: [PATCH 14/15] Fix lint --- src/common/cuda_utils.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/common/cuda_utils.cc b/src/common/cuda_utils.cc index 2a9b3ec4e907..6551d4c1f193 100644 --- a/src/common/cuda_utils.cc +++ b/src/common/cuda_utils.cc @@ -23,6 +23,8 @@ * \brief Common CUDA utilities. */ +#include + #include #include #include "cuda_utils.h" From 61a2aad489417284a2175003f832981c5d049057 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 21 Aug 2019 11:04:35 -0700 Subject: [PATCH 15/15] Actually fix lint --- src/common/cuda_utils.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/common/cuda_utils.cc b/src/common/cuda_utils.cc index 6551d4c1f193..f38b2f8b5490 100644 --- a/src/common/cuda_utils.cc +++ b/src/common/cuda_utils.cc @@ -23,10 +23,11 @@ * \brief Common CUDA utilities. */ -#include - #include #include + +#include + #include "cuda_utils.h" #if MXNET_USE_CUDA
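Putting the pieces of the stride-1 path together, the host side picks the widest load that evenly divides the row, converts the row length into a number of loads, derives rows per block, and sizes the grid so every row is covered. A standalone sketch of that launch math (illustration only; pick_load_bytes mirrors the N % 8 / N % 4 / N % 2 selection of get_load_type but is not the MXNet function, and the shape M = 128, N = 1000 is made up)::

    // Standalone sketch of the stride-1 launch math -- illustration only.
    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    // Widest vectorized load (in bytes) that evenly divides a row of row_bytes bytes.
    int pick_load_bytes(size_t row_bytes) {
      if (row_bytes % 8 == 0) return 8;
      if (row_bytes % 4 == 0) return 4;
      if (row_bytes % 2 == 0) return 2;
      return 1;
    }

    int main() {
      const int threads_per_block = 512;  // softmax_threads_per_block
      const int warp_size = 32;
      const size_t M = 128;               // row length along the softmax axis (made up)
      const size_t N = 1000;              // number of rows (made up)
      const size_t dtype_bytes = 4;       // float32 input

      const int load_bytes = pick_load_bytes(M * dtype_bytes);
      const size_t loads_per_row = M * dtype_bytes / load_bytes;  // value passed as row_size

      // Heuristic analogous to get_rows_per_block: at least two loads per thread,
      // warp count per row rounded up to a power of two (clamped here to the block).
      const size_t want_threads = (loads_per_row + 1) / 2;
      int warps = static_cast<int>((want_threads + warp_size - 1) / warp_size);
      warps = std::min(warps, threads_per_block / warp_size);
      int pow2 = 1;
      while (pow2 < warps) pow2 *= 2;
      const int rows_per_block = threads_per_block / (warp_size * pow2);
      const int nblocks = static_cast<int>((N + rows_per_block - 1) / rows_per_block);

      std::printf("load width: %d B, rows per block: %d, blocks: %d\n",
                  load_bytes, rows_per_block, nblocks);
      return 0;
    }

For a 128-element float32 row this selects 8-byte loads, packs 16 rows into each 512-thread block, and launches 63 blocks to cover 1000 rows.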