diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h index 168dc686e7ad..50f0c671bf48 100644 --- a/src/common/cuda/rtc/backward_functions-inl.h +++ b/src/common/cuda/rtc/backward_functions-inl.h @@ -32,217 +32,217 @@ const char backward_function_definitions[] = R"code( namespace op { template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_relu(const DTypeGrad grad, const DType val) { if (isnan(val)) return val; return val > 0 ? grad : 0; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_sigmoid(const DTypeGrad grad, const DType out) { return grad * out * (1 - out); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_softrelu(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; return grad * sigmoid(v); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_softsign(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; const auto ap1 = 1 + op::abs(v); return grad / (ap1 * ap1); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_abs(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; return grad * op::sign(v); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_exp(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; return grad * op::exp(v); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_expm1(const DTypeGrad grad, const DType val) { return backward_exp(grad, val); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_log(const DTypeGrad grad, const DType val) { return grad / val; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_log10(const DTypeGrad grad, const DType val) { return grad / (val * op::log(static_cast(10))); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_log2(const DTypeGrad grad, const DType val) { return grad / (val * op::log(static_cast(2))); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_log1p(const DTypeGrad grad, const DType val) { return grad / (1 + val); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_sin(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; return grad * op::cos(v); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_cos(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; return -grad * op::sin(v); } // Uses output from tan template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_tan(const DTypeGrad grad, const DType out) { return grad * (out * out + 1); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline 
mixed_type backward_arcsin(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; return grad / op::sqrt(1 - v*v); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_arccos(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; return -grad / op::sqrt(1 - v*v); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_arctan(const DTypeGrad grad, const DType val) { return grad / (1 + val*val); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_degrees(const DTypeGrad grad, const DType /* val */) { return op::degrees(grad); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_radians(const DTypeGrad grad, const DType /* val */) { return op::radians(grad); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_sinh(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; return grad * op::cosh(v); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_cosh(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; return grad * op::sinh(v); } // Uses tanh output template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_tanh(const DTypeGrad grad, const DType out) { return grad * (1 - out * out); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_arcsinh(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; return grad / op::sqrt(v * v + 1); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_arccosh(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; return grad / op::sqrt(v * v - 1); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_arctanh(const DTypeGrad grad, const DType val) { return grad / (1 - val * val); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_sqrt(const DTypeGrad grad, const DType out) { return 0.5 * grad / out; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_rsqrt(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; const auto inv = 1 / v; return -0.5 * grad * op::sqrt(inv) * inv; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_cbrt(const DTypeGrad grad, const DType out) { return grad / (3.0f * out * out); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_rcbrt(const DTypeGrad grad, const DType val) { - const typename type_util::mixed_type::type v = val; + const mixed_type v = val; const auto inv = 1 / v; return -1.f/3.f * grad * op::cbrt(inv) * inv; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_square(const 
DTypeGrad grad, const DType val) { return 2 * val * grad; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type rdiv_grad(const DType val, const DType2 val2) { return -val2 / (val * val); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type div_grad(const DType val, const DType2 val2) { - const typename type_util::mixed_type::type temp = val2; + const mixed_type temp = val2; return op::reciprocal(temp); } @@ -283,87 +283,87 @@ __device__ inline DType rmod_grad(const DType val, } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type power_grad(const DType val, const DType2 val2) { return op::power(val, val2 - 1.f) * val2; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type power_rgrad(const DType val, const DType2 val2) { - const typename type_util::mixed_type::type temp = val; + const mixed_type temp = val; return op::power(val, val2) * op::log(temp); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type rpower_grad(const DType val, const DType2 val2) { - const typename type_util::mixed_type::type temp = val2; + const mixed_type temp = val2; return val * op::log(temp); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type hypot_grad_left(const DType val, const DType2 val2) { return val / op::hypot(val, val2); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type hypot_grad_right(const DType val, const DType2 val2) { return val2 / op::hypot(val, val2); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type copysign_grad(const DType val, const DType2 val2) { return (val >= 0 && val2 >= 0) || (val < 0 && val2 < 0) ? 
1 : -1; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type arctan2_grad(const DType val, const DType2 val2) { return val2 / (val * val + val2 * val2); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type rarctan2_grad(const DType val, const DType2 val2) { return val / (val * val + val2 * val2); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type arctan2_rgrad(const DType val, const DType2 val2) { return -rarctan2_grad(val, val2); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type ldexp_grad(const DType val, const DType2 val2) { return op::power(static_cast(2), val2); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type rldexp_grad(const DType val, const DType2 val2) { - using mixed_type = typename type_util::mixed_type::type; - return val2 * op::power(static_cast(2), val) * op::log(static_cast(2)); + using type = mixed_type; + return val2 * op::power(static_cast(2), val) * op::log(static_cast(2)); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_clip(const DTypeGrad grad, const DType val, const float a_min, const float a_max) { if (val > a_max || val < a_min) { @@ -374,35 +374,32 @@ backward_clip(const DTypeGrad grad, const DType val, } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_reciprocal(const DTypeGrad grad, const DType val) { return -grad / (val * val); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_erf(const DTypeGrad grad, const DType val) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type v = val; - constexpr mixed_type my_pi = pi; + const mixed_type v = val; + constexpr mixed_type my_pi = pi; return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_erfinv(const DTypeGrad grad, const DType val) { - using mixed_type = typename type_util::mixed_type::type; - constexpr mixed_type my_pi = pi; - const mixed_type g = grad; - const mixed_type v = val; + constexpr mixed_type my_pi = pi; + const mixed_type g = grad; + const mixed_type v = val; return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_gamma(const DTypeGrad grad, const DType val) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type v = val; + const mixed_type v = val; if (type_util::is_same::value) { return grad * op::gamma(v) * op::special_functions::cephes::psi(v); } else { @@ -411,10 +408,9 @@ backward_gamma(const DTypeGrad grad, const DType val) { } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_gammaln(const DTypeGrad grad, const DType val) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type v = val; + const mixed_type v = val; if (type_util::is_same::value) { return grad * op::special_functions::cephes::psi(v); } else { @@ -423,10 +419,9 @@ backward_gammaln(const DTypeGrad grad, const DType val) { } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_digamma(const DTypeGrad grad, const DType val) { - using mixed_type = typename 
type_util::mixed_type::type; - const mixed_type v = val; + const mixed_type v = val; if (type_util::is_same::value) { return grad * op::special_functions::trigamma(v); } else { @@ -435,7 +430,7 @@ backward_digamma(const DTypeGrad grad, const DType val) { } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type backward_gelu(const DTypeGrad grad, const DType val) { return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) + val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f)); diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h index f85916f1ef96..f4d08e6d1a60 100644 --- a/src/common/cuda/rtc/forward_functions-inl.h +++ b/src/common/cuda/rtc/forward_functions-inl.h @@ -32,6 +32,7 @@ const char function_definitions_util[] = R"code( #define INT_MAX (2147483647) namespace op { +using type_util::mixed_type; template struct LoadType { @@ -241,44 +242,44 @@ __device__ inline bool_t isfinite(const DType val) { } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type add(const DType a, const DType2 b) { return a + b; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type sub(const DType a, const DType2 b) { return a - b; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type rsub(const DType a, const DType2 b) { return b - a; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type mul(const DType a, const DType2 b) { return a * b; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type div(const DType a, const DType2 b) { return a / b; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type rdiv(const DType a, const DType2 b) { return b / a; } #define DEFINE_BINARY_MATH_FUNC(name, double_version, float_version) \ template \ -__device__ inline typename type_util::mixed_type::type \ +__device__ inline mixed_type \ name (const DType a, const DType2 b) { \ if (type_util::has_double_or_integral::value) { \ return double_version ((double)a, (double)b); \ @@ -288,7 +289,7 @@ name (const DType a, const DType2 b) { \ } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type power (const DType a, const DType2 b) { if (type_util::has_double::value) { return ::pow ((double)a, (double)b); \ @@ -298,34 +299,34 @@ power (const DType a, const DType2 b) { } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type rpow(const DType a, const DType2 b) { return power(b, a); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type max(const DType a, const DType2 b) { if (isnan(a)) return a; return a > b ? a : b; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type fmax(const DType a, const DType2 b) { if (isnan(b)) return a; return a > b ? a : b; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type min(const DType a, const DType2 b) { if (isnan(a)) return a; return a < b ? a : b; } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type fmin(const DType a, const DType2 b) { if (isnan(b)) return a; return a < b ? 
a : b; @@ -334,7 +335,7 @@ fmin(const DType a, const DType2 b) { DEFINE_BINARY_MATH_FUNC(hypot, ::hypot, ::hypotf) template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type mod(const DType a, const DType2 b) { if (b == 0) { return 0; @@ -359,7 +360,7 @@ mod(const DType a, const DType2 b) { } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type fmod(const DType a, const DType2 b) { if (b == 0) { return 0; @@ -368,110 +369,98 @@ fmod(const DType a, const DType2 b) { } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type rmod(const DType a, const DType2 b) { return op::mod(b, a); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type rfmod(const DType a, const DType2 b) { return op::fmod(b, a); } template __device__ inline DType equal(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a == real_b ? 1 : 0; } template __device__ inline DType not_equal(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a != real_b ? 1 : 0; } template __device__ inline DType greater(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a > real_b ? 1 : 0; } template __device__ inline DType greater_equal(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a >= real_b ? 1 : 0; } template __device__ inline DType less(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a < real_b ? 1 : 0; } template __device__ inline DType less_equal(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a <= real_b ? 1 : 0; } template __device__ inline bool_t np_equal(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a == real_b ? true : false; } template __device__ inline bool_t np_not_equal(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a != real_b ? true : false; } template __device__ inline bool_t np_greater(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a > real_b ? 
true : false; } template __device__ inline bool_t np_greater_equal(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a >= real_b ? true : false; } template __device__ inline bool_t np_less(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a < real_b ? true : false; } template __device__ inline bool_t np_less_equal(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a <= real_b ? true : false; } @@ -501,7 +490,7 @@ __device__ inline DType2 rcopysign(const DType a, const DType2 b) { } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type lcm(const DType a, const DType2 b) { if (type_util::is_integral::value && type_util::is_integral::value) { @@ -542,7 +531,7 @@ lcm(const DType a, const DType2 b) { } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type gcd(const DType a, const DType2 b) { if (type_util::is_integral::value && type_util::is_integral::value) { @@ -585,42 +574,39 @@ gcd(const DType a, const DType2 b) { } template -__device__ inline typename type_util::mixed_type::type bitwise_xor(const DType a, +__device__ inline mixed_type bitwise_xor(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a ^ real_b; } template -__device__ inline typename type_util::mixed_type::type bitwise_or(const DType a, +__device__ inline mixed_type bitwise_or(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a | real_b; } template -__device__ inline typename type_util::mixed_type::type bitwise_and(const DType a, +__device__ inline mixed_type bitwise_and(const DType a, const DType2 b) { - using mixed_type = typename type_util::mixed_type::type; - const mixed_type real_a = a; - const mixed_type real_b = b; + const mixed_type real_a = a; + const mixed_type real_b = b; return real_a & real_b; } DEFINE_BINARY_MATH_FUNC(arctan2, ::atan2, ::atan2f) template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type rarctan2(const DType a, const DType2 b) { return arctan2(b, a); } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type ldexp(const DType a, const DType2 b) { if (type_util::has_double_or_integral::value) { return a * ::pow(2.0, static_cast(b)); @@ -630,7 +616,7 @@ ldexp(const DType a, const DType2 b) { } template -__device__ inline typename type_util::mixed_type::type +__device__ inline mixed_type rldexp(const DType a, const DType2 b) { return ldexp(b, a); } diff --git a/src/common/cuda/rtc/reducer-inl.h b/src/common/cuda/rtc/reducer-inl.h index 93b702788c46..259d0e060a57 100644 --- a/src/common/cuda/rtc/reducer-inl.h +++ b/src/common/cuda/rtc/reducer-inl.h @@ -94,6 +94,405 @@ struct 
sum { residual = 0; } }; + +/*! \brief maximum reducer */ +struct maximum { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { // NOLINT(*) + if (!util::isnan(dst)) { + if (!(dst >= src)) dst = src; + } + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, + volatile DType& none) { + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv = -2*DBL_MAX; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &none) { + SetInitValue(initv); + } +}; + +/*! \brief minimum reducer */ +struct minimum { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { + if (!util::isnan(dst)) { + if (!(dst <= src)) dst = src; + } + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, + volatile DType& none) { + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv = 2*DBL_MAX; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &none) { + SetInitValue(initv); + } +}; + +/*! \brief product reducer */ +struct product { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { + dst = op::mul(dst, src); + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, + volatile DType& none) { + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! 
\brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv = 1; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &none) { + SetInitValue(initv); + } +}; + +/*! \brief sum reducer that ignores NaN values in the input */ +struct nansum { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { + if (util::isnan(src)) return; + dst = op::add(dst, src); + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType src, + volatile DType& residual) { + if (util::isnan(src)) return; + DType y = src - residual; + DType t = dst + y; + residual = (t - dst) - y; + dst = t; + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + DType t1 = dst_val + src_val; + DType e = t1 - src_val; + DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual; + dst_val = t1 + t2; + dst_residual = t2 - (dst_val - t1); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType & initv) { + initv = 0; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &residual) { + SetInitValue(initv); + residual = 0; + } +}; + +/*! \brief product reducer that ignores NaN values in the input */ +struct nanprod { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { + if (util::isnan(src)) return; + dst = op::mul(dst, src); + } + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType src, + volatile DType& none) { + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + Reduce(dst_val, src_val); + } + /*! 
\brief finalize reduction */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType & initv) { + initv = 1; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &none) { + SetInitValue(initv); + } +}; + +struct nrm2 { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile AType& sum_of_squares, volatile DType src) { + sum_of_squares = op::add(sum_of_square, src * src); + } + /*! \brief do stable reduction into dst */ + template + __device__ inline static void Reduce(volatile AType& sum_of_squares, + volatile DType src, volatile DType& scale) { + if (src != 0) { + DType abs = op::abs(src); + if (scale < abs) { + sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs); + scale = abs; + } else { + sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale); + } + } + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + dst_val = op::add(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale, + volatile DType& src_ssq, volatile DType& src_scale) { + if (dst_scale != 0 && dst_scale >= src_scale) { + dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale); + } else if (src_scale != 0 && dst_scale < src_scale) { + dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale / src_scale); + dst_scale = src_scale; + } + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& sum_of_squares) { + sum_of_squares = op::sqrt(sum_of_squares); + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& sum_of_squares, volatile DType& scale) { + sum_of_squares = scale * op::sqrt(sum_of_squares); + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &sum_of_squares) { + sum_of_squares = 0; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &sum_of_squares, DType &scale) { + SetInitValue(sum_of_squares); + scale = 0; + } +}; + +struct nrmlp { + double lp; + /* \brief power for Lp norm */ + __device__ inline static double lp_power(volatile double src, volatile double p) { + if (p != 0.0) { + if (src == 0.0) { + return src; + } else { + return op::power(src, p); + } + } else { // 0-norm, sparsity + return static_cast(src != 0); + } + } + + /*! \brief do reduction into dst */ + template + __device__ inline void Reduce(volatile AType& sum_of_powers, volatile DType src) { + if (src != 0) { + sum_of_powers += AType(lp_power(static_cast(src), lp)); + } + } + + /*! 
\brief do stable reduction into dst */ + template + __device__ inline void Reduce(volatile AType& sum_of_powers, volatile DType src, + volatile DType& scale) { + if (src != 0) { + DType src_abs = op::abs(src); + if (scale < src_abs) { + sum_of_powers = sum_of_powers * AType(lp_power(static_cast(scale / src_abs), lp)); + sum_of_powers = sum_of_powers + 1; + scale = src_abs; + } else { + sum_of_powers = sum_of_powers + AType(lp_power(static_cast(src_abs / scale), lp)); + } + } + } + + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + dst_val = dst_val + src_val; + } + + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale, + volatile DType& src_ssq, volatile DType& src_scale) { + if (dst_scale != 0 && dst_scale >= src_scale) { + dst_ssq = dst_ssq + src_ssq * DType(lp_power(static_cast(src_scale / dst_scale), 2)); + } else if (src_scale != 0 && dst_scale < src_scale) { + dst_ssq = src_ssq + dst_ssq * DType(lp_power(static_cast(dst_scale / src_scale), 2)); + dst_scale = src_scale; + } + } + + /*! \brief finalize reduction result */ + template + __device__ inline void Finalize(volatile DType& sum_of_powers) { + if (lp != 0.0) { + sum_of_powers = DType(lp_power(static_cast(sum_of_powers), 1.0 / lp)); + } + } + + /*! \brief finalize reduction result */ + template + __device__ inline void Finalize(volatile DType& sum_of_powers, volatile DType& scale) { + if (lp != 0.0) { + sum_of_powers = scale * DType(lp_power(static_cast(sum_of_powers), 1.0 / lp)); + } + } + + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &sum_of_powers) { + sum_of_powers = 0; + } + + /*! 
+ *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &sum_of_powers, DType &scale) { + SetInitValue(sum_of_powers); + scale = 0; + } +}; } // namespace red )code"; diff --git a/src/common/cuda/rtc/util-inl.h b/src/common/cuda/rtc/util-inl.h index 372390fdc117..b4266030be1f 100644 --- a/src/common/cuda/rtc/util-inl.h +++ b/src/common/cuda/rtc/util-inl.h @@ -174,74 +174,97 @@ struct enable_if { }; template -struct mixed_type; +struct mixed_type_helper; template -struct mixed_type::value>::type> { +struct mixed_type_helper::value>::type> { using type = float64; }; template -struct mixed_type { +struct mixed_type_helper { using type = float64; }; template -struct mixed_type::value && - !is_same::value>::type> { +struct mixed_type_helper::value && + !is_same::value>::type> { using type = float32; }; template -struct mixed_type::value>::type> { +struct mixed_type_helper::value>::type> { using type = float32; }; template -struct mixed_type::value || - is_integral::value>::type> { +struct mixed_type_helper::value || + is_integral::value>::type> { using type = float16; }; template -struct mixed_type::value>::type> { +struct mixed_type_helper::value>::type> { using type = float16; }; template -struct mixed_type::value && - is_integral::value && - !is_same::value && - sizeof(T) <= sizeof(U)>::type> { +struct mixed_type_helper::value && + is_integral::value && + !is_same::value && + sizeof(T) <= sizeof(U)>::type> { using type = U; }; template -struct mixed_type::value && - is_integral::value && - !is_same::value && - sizeof(T) < sizeof(U)>::type> { +struct mixed_type_helper::value && + is_integral::value && + !is_same::value && + sizeof(T) < sizeof(U)>::type> { using type = U; }; template -struct mixed_type::value && - sizeof(T) < sizeof(bool_t)>::type> { +struct mixed_type_helper::value && + sizeof(T) < sizeof(bool_t)>::type> { using type = index_t; }; template -struct mixed_type::value && - sizeof(T) < sizeof(bool_t)>::type> { +struct mixed_type_helper::value && + sizeof(T) < sizeof(bool_t)>::type> { using type = index_t; }; template -struct mixed_type::value && - sizeof(T) == sizeof(bool_t)>::type> { +struct mixed_type_helper::value && + sizeof(T) == sizeof(bool_t)>::type> { using type = T; }; +template +struct multi_mixed_type_helper; + +template <> +struct multi_mixed_type_helper<> { + using type = void; +}; + +template +struct multi_mixed_type_helper { + using type = T; +}; + +template +struct multi_mixed_type_helper { + using type = typename mixed_type_helper::type>::type; +}; + +template +using mixed_type = typename multi_mixed_type_helper::type; + } // namespace type_util )code"; @@ -254,6 +277,7 @@ enum class OpReqType { }; constexpr int kRTCMaxThreadsPerBlock = 512; +constexpr int warp_size = 32; namespace util { @@ -377,6 +401,49 @@ __device__ inline bool isnan(volatile const float16 &val) { return ::isnan(__half2float(const_cast(val))); } +template +__device__ inline T warp_reduce(T value, OP redfun) { +#pragma unroll + for (int i = warp_size / 2; i >= 1; i /= 2) { + if (NVALUES > i) value = redfun(value, __shfl_down_sync(0xffffffff, value, i)); + } + return value; +} + +template +__device__ inline T grouped_warp_reduce(T value, OP redfun, const int group_size) { + for (int i = 1; i < group_size; i *= 2) { + value = redfun(value, __shfl_down_sync(0xffffffff, value, i)); + } + return value; +} + +template +__device__ inline T grouped_warp_allreduce(T value, OP redfun, const int group_size) { + value = grouped_warp_reduce(value, 
redfun, group_size); + return __shfl_sync(0xffffffff, value, 0, group_size); +} + +template +__device__ inline T strided_grouped_warp_reduce(T value, OP redfun, const int group_size) { + for (int i = warp_size / 2; i >= group_size; i /= 2) { + value = redfun(value, __shfl_down_sync(0xffffffff, value, i)); + } + return value; +} + +template +__device__ inline T strided_grouped_warp_allreduce(T value, OP redfun, const int group_size) { + value = strided_grouped_warp_reduce(value, redfun, group_size); + for (int i = group_size; i < warp_size; i *= 2) { + T tmp = __shfl_up_sync(0xffffffff, value, i); + if (threadIdx.x % warp_size >= i) { + value = tmp; + } + } + return value; +} + } // namespace util )code"; } // namespace rtc diff --git a/src/common/cuda/rtc/vectorization-inl.h b/src/common/cuda/rtc/vectorization-inl.h index 5cbc4599db33..96205fceab3e 100644 --- a/src/common/cuda/rtc/vectorization-inl.h +++ b/src/common/cuda/rtc/vectorization-inl.h @@ -41,6 +41,8 @@ const char vectorization_support_string[] = R"code( namespace vector { +constexpr int vectorized_kernel_thread_num = 512; + template struct VectorType { static_assert(size <= 32, "VectorType needs to have size of at most 32B"); @@ -166,7 +168,7 @@ class VectorizedAccessor { if (aligned) { alignment_ = 0; aligned_ptr_ = reinterpret_cast(ptr); - n_elems_ = (size + nvec- 1) / nvec; + n_elems_ = (size + nvec - 1) / nvec; } else { size_t ptr_as_number = reinterpret_cast(ptr); alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType); @@ -360,6 +362,8 @@ constexpr int vectorized_kernel_thread_num = 512; * \param lead_input_num number of input to use for checking alignment * (in case only a subset of inputs is used vectorized). * Default is 0. + * \param blocks if provided and not 0, will launch the specified number of thread blocks. + * Default is 0. 
*/ template void VectorizedKernelRTCLauncher(const std::string ¶meters, @@ -373,7 +377,8 @@ void VectorizedKernelRTCLauncher(const std::string ¶meters, const std::vector &inputs, const std::vector &outputs, const int dev_id, - const int lead_input_num = 0) { + const int lead_input_num = 0, + const index_t blocks = 0) { const index_t N = lead_dim * other_dim; nvec = std::min(nvec, 4); // Use at most 4-wide vectors if (N != 0) { @@ -435,11 +440,16 @@ void VectorizedKernelRTCLauncher(const std::string ¶meters, lead_dim, nvec, common::mshadow_type_info( inputs[lead_input_num].type_flag_).size); - size_t num_elements = other_dim * num_aligned_elements; constexpr int threads = vectorized_kernel_thread_num; - constexpr int max_blocks = 65535; - index_t blocks = std::min(static_cast((num_elements + threads - 1) / threads), - max_blocks); + index_t num_blocks; + if (blocks != 0) { + num_blocks = blocks; + } else { + size_t num_elements = other_dim * num_aligned_elements; + num_blocks = (num_elements + threads - 1) / threads; + constexpr int max_blocks = 65535; + num_blocks = std::min(static_cast(num_blocks), max_blocks); + } std::vector args = {¶ms, &lead_dim, &other_dim, &N, &num_aligned_elements}; auto function = common::cuda::rtc::get_function(kernel_builder, @@ -448,7 +458,7 @@ void VectorizedKernelRTCLauncher(const std::string ¶meters, dev_id); common::cuda::rtc::launch(function, - {static_cast(blocks), 1, 1}, + {static_cast(num_blocks), 1, 1}, {static_cast(threads), 1, 1}, 0, s, &args); } diff --git a/src/common/cuda/utils.cc b/src/common/cuda/utils.cc index b87c39386604..7aa936dc9d4d 100644 --- a/src/common/cuda/utils.cc +++ b/src/common/cuda/utils.cc @@ -29,6 +29,7 @@ #include #include "utils.h" +#include "../utils.h" #if MXNET_USE_CUDA @@ -36,25 +37,6 @@ namespace mxnet { namespace common { namespace cuda { -namespace { - bool IsPower2(size_t N) { - return ((N & (N - 1)) == 0) && N != 0; - } - - size_t RoundToPower2(size_t N) { - size_t ret = 1; - size_t copyN = N; - while (N >= 2) { - ret *= 2; - N /= 2; - } - if (ret < copyN) { - ret *= 2; - } - return ret; - } -} // namespace - int get_load_type(size_t N) { using namespace mshadow; if (N % 8 == 0) { diff --git a/src/common/cuda/utils.h b/src/common/cuda/utils.h index fc4d40c26e1b..a203ba55a773 100644 --- a/src/common/cuda/utils.h +++ b/src/common/cuda/utils.h @@ -811,6 +811,14 @@ __device__ inline T warp_reduce(T value, OP redfun) { return value; } +template +__device__ inline T grouped_warp_allreduce(T value, OP redfun, const int group_size) { + for (int i = 1; i < group_size; i *= 2) { + value = redfun(value, __shfl_down_sync(0xffffffff, value, i)); + } + return __shfl_sync(0xffffffff, value, 0, group_size); +} + template __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) { float v = static_cast(value); diff --git a/src/common/utils.h b/src/common/utils.h index dfd32ac6f311..40376e993a0b 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -977,6 +977,27 @@ inline void AlignedMemFree(void* ptr) { } +inline index_t div_round(const index_t a, const index_t b) { + return (a + b - 1) / b; +} + +inline bool IsPower2(size_t N) { + return ((N & (N - 1)) == 0) && N != 0; +} + +inline size_t RoundToPower2(size_t N) { + size_t ret = 1; + size_t copyN = N; + while (N >= 2) { + ret *= 2; + N /= 2; + } + if (ret < copyN) { + ret *= 2; + } + return ret; +} + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 
05c8b9a61278..2251ff81ea04 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -31,6 +31,7 @@ #include #include #include +#include #include "./operator_tune.h" #include "../engine/openmp.h" @@ -367,40 +368,30 @@ struct AccType { break; \ case mshadow::kUint8: \ { \ - typedef uint8_t DType; \ - typedef uint8_t AType; \ LOG(FATAL) << "This operation only support " \ "floating point types not uint8"; \ } \ break; \ case mshadow::kInt8: \ { \ - typedef int8_t DType; \ - typedef int8_t AType; \ LOG(FATAL) << "This operation only support " \ "floating point types not int8"; \ } \ break; \ case mshadow::kInt32: \ { \ - typedef int32_t DType; \ - typedef int32_t AType; \ LOG(FATAL) << "This operation only support " \ "floating point types, not int32"; \ } \ break; \ case mshadow::kInt64: \ { \ - typedef int64_t DType; \ - typedef int64_t AType; \ LOG(FATAL) << "This operation only support " \ "floating point types, not int64"; \ } \ break; \ case mshadow::kBool: \ { \ - typedef bool DType; \ - typedef int64_t AType; \ LOG(FATAL) << "This operation only support " \ "floating point types, not bool"; \ } \ @@ -475,21 +466,18 @@ struct AccType { switch (type) { \ case mshadow::kFloat32: \ { \ - typedef float DType; \ LOG(FATAL) << "This operation only support " \ "integer types, not float32"; \ } \ break; \ case mshadow::kFloat64: \ { \ - typedef double DType; \ LOG(FATAL) << "This operation only support " \ "integer types, not float64"; \ } \ break; \ case mshadow::kFloat16: \ { \ - typedef mshadow::half::half_t DType; \ LOG(FATAL) << "This operation only support " \ "integer types, not float16"; \ } \ @@ -532,21 +520,18 @@ struct AccType { switch (type) { \ case mshadow::kFloat32: \ { \ - typedef float DType; \ LOG(FATAL) << "This operation only support " \ "integer types, not float32"; \ } \ break; \ case mshadow::kFloat64: \ { \ - typedef double DType; \ LOG(FATAL) << "This operation only support " \ "integer types, not float64"; \ } \ break; \ case mshadow::kFloat16: \ { \ - typedef mshadow::half::half_t DType; \ LOG(FATAL) << "This operation only support " \ "integer types, not float16"; \ } \ diff --git a/src/operator/nn/log_softmax.cu b/src/operator/nn/log_softmax.cu index 396a4e8e2cb3..485290dc4fa3 100644 --- a/src/operator/nn/log_softmax.cu +++ b/src/operator/nn/log_softmax.cu @@ -29,11 +29,10 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(log_softmax) -.set_attr("FCompute", SoftmaxCompute); +.set_attr("FCompute", SoftmaxRTCCompute{"log_softmax_fwd"}); NNVM_REGISTER_OP(_backward_log_softmax) -.set_attr("FCompute", SoftmaxGradCompute); +.set_attr("FCompute", SoftmaxRTCGradCompute{"op::left", "log_softmax_bwd"}); NNVM_REGISTER_OP(masked_log_softmax) .set_attr("FCompute", MaskedSoftmaxCompute *s, DType *out, DType *ograd, } #ifdef __CUDACC__ -template -__global__ void softmax_compute_kernel(DType *in, OType *out, IType *length, - index_t M, int axis, Shape sshape, - Shape stride, const double temperature) { - const unsigned x_size = 1 << x_bits; - __shared__ AType smem[x_size]; - index_t sa = stride[axis]; - index_t base = unravel_dot(blockIdx.x, sshape, stride); - index_t x = threadIdx.x; - const index_t len = length == nullptr ? M : static_cast(length[blockIdx.x]); - - red::maximum::SetInitValue(smem[x]); - for (index_t i = x; i < len; i += x_size) { - smem[x] = ::max(smem[x], negate ? 
-in[base + i*sa] : in[base + i*sa]); - } - __syncthreads(); - cuda::Reduce1D(smem); - __syncthreads(); - DType smax = smem[0]; - __syncthreads(); - - red::sum::SetInitValue(smem[x]); - DType val; - for (index_t i = x; i < len; i += x_size) { - val = negate ? -in[base + i*sa]:in[base + i*sa]; - smem[x] += static_cast(expf((val - smax) / static_cast(temperature))); - } - __syncthreads(); - cuda::Reduce1D(smem); - __syncthreads(); - AType ssum = smem[0]; - __syncthreads(); - - for (index_t i = x; i < M; i += x_size) { - val = negate ? -in[base + i*sa] : in[base + i*sa]; - out[base + i*sa] = - (i < len) ? OType(OP::Map((val - smax)/static_cast(temperature), ssum)) : OType(0.0f); - } -} - const int softmax_threads_per_block = 512; -template -__global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, IType *length, - const index_t M, const double temperature, - const int rows_per_block, const index_t total_rows) { - __shared__ AType scratch[softmax_threads_per_block]; - __shared__ LType persistent_storage[20 * 1024 / sizeof(LType)]; - const int warp_size = 32; - const int threads_per_row = softmax_threads_per_block / rows_per_block; - const int my_local_row = threadIdx.x / threads_per_row; - const int my_row = blockIdx.x * rows_per_block + my_local_row; - if (my_row >= total_rows) return; - const int my_id = threadIdx.x % threads_per_row; - const int entries_per_load = sizeof(LType)/sizeof(DType); - const index_t len = length == nullptr ? M : static_cast(length[my_row]); - // Due to usage of MSHADOW_TYPE_SWITCH macro we are generating - // kernels where sizeof(LType) may be less than sizeof(DType), - // resulting in entries_per_load being 0. - // This is not a valid combination and is being checked against - // in the launcher code. This switch here is just to silence - // the division by zero warning generated for such invalid cases. - const int row_length = entries_per_load > 0 ? M / entries_per_load : 0; - - const LType* in_aligned = reinterpret_cast(in); - size_t base = my_row * row_length; - - for (index_t i = my_id; i < row_length; i += threads_per_row) { - persistent_storage[my_local_row * row_length + i] = in_aligned[base + i]; - } - DType * row = reinterpret_cast(persistent_storage + my_local_row * row_length); - __syncthreads(); - - DType my_max_value; - red::maximum::SetInitValue(my_max_value); - - for (index_t i = my_id; i < len; i += threads_per_row) { - my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]); - } - scratch[threadIdx.x] = my_max_value; - __syncthreads(); - for (int size = threads_per_row / 2; size >= warp_size; size /= 2) { - if (my_id < size) { - scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]); - } - __syncthreads(); - } - if (my_id < warp_size) { - AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x], - [](AType x, AType y) { return ::max(x, y); }); - scratch[threadIdx.x] = my_value; - } - __syncthreads(); - DType smax = scratch[threadIdx.x - threadIdx.x % threads_per_row]; - __syncthreads(); - - AType my_sum; - red::sum::SetInitValue(my_sum); - - for (index_t i = my_id; i < len; i += threads_per_row) { - const DType val = negate ? 
-row[i] : row[i]; - my_sum += static_cast(expf((val - smax) / static_cast(temperature))); - } - scratch[threadIdx.x] = my_sum; - __syncthreads(); - for (int size = threads_per_row / 2; size >= warp_size; size /= 2) { - if (my_id < size) { - scratch[threadIdx.x] += scratch[threadIdx.x + size]; - } - __syncthreads(); - } - if (my_id < warp_size) { - AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x], - [](AType x, AType y) { return x + y;}); - scratch[threadIdx.x] = my_value; - } - __syncthreads(); - - AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row]; - __syncthreads(); - - for (index_t i = my_id; i < M; i += threads_per_row) { - const DType val = negate ? -row[i] : row[i]; - row[i] = (i < len) ? DType(OP::Map((val - smax)/static_cast(temperature), ssum)) : - DType(0.0f); - } - __syncthreads(); - - LType* out_aligned = reinterpret_cast(out); - - for (index_t i = my_id; i < row_length; i += threads_per_row) { - out_aligned[base + i] = persistent_storage[my_local_row * row_length + i]; - } -} - template MSHADOW_XINLINE index_t get_mask_position(const index_t idx, const Shape& data_shape, const Shape& mask_shape, int axis, index_t* stride_axis) { @@ -665,45 +529,6 @@ __global__ void masked_softmax_stride1_kernel(const DType *in, DType *out, bool } } -template -inline void Softmax(Stream *s, DType *in, OType *out, IType *length, - Shape shape, int axis, const double temperature) { - const int x_bits = 7; - const int x_size = 1 << x_bits; - index_t M = shape[axis]; - if (M == 0 || shape.Size() == 0) return; - index_t N = shape.Size()/M; - Shape stride = calc_stride(shape); - Shape sshape = shape; - sshape[axis] = 1; - - const size_t DSize = sizeof(DType); - // Using 20 kB of shared memory for persistent storage in the optimized case - const size_t max_opt_M = 20 * 1024 / DSize; - if (stride[axis] == 1 && - static_cast(M) <= max_opt_M && - std::is_same::value) { - int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType)); - MXNET_LOAD_TYPE_SWITCH(ltype, LType, { - int rows_per_block = mxnet::common::cuda::get_rows_per_block(M * - sizeof(DType) / sizeof(LType), - softmax_threads_per_block); - int nblocks = (N + rows_per_block - 1) / rows_per_block; - CHECK_LE(sizeof(DType), sizeof(LType)); - softmax_stride1_compute_kernel - <<::GetStream(s)>>>( - in, out, length, M, temperature, rows_per_block, N); - }); - MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_compute_kernel); - } else { - softmax_compute_kernel - <<::GetStream(s)>>>( - in, out, length, M, axis, sshape, stride, temperature); - MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel); - } -} - template inline void MaskedSoftmax(Stream *s, DType *in, OType *out, bool *mask, @@ -784,120 +609,6 @@ inline void MaskedSoftmax(Stream *s, DType *in, OType *out, bool *mask, } } -template -__global__ void softmax_stride1_grad_kernel(const OType *out, const OType *ograd, - DType *igrad, const IType *length, - const index_t M, - const double temperature, - const int rows_per_block, - const index_t total_rows) { - __shared__ AType scratch[softmax_threads_per_block]; - __shared__ LType persistent_storage[20 * 1024 / sizeof(LType)]; - const int warp_size = 32; - const int threads_per_row = softmax_threads_per_block / rows_per_block; - const int my_local_row = threadIdx.x / threads_per_row; - const int my_row = blockIdx.x * rows_per_block + my_local_row; - if (my_row >= total_rows) return; - const int my_id = threadIdx.x % threads_per_row; - const int entries_per_load = sizeof(LType)/sizeof(DType); - const index_t len = 
length == nullptr ? M : static_cast(length[my_row]); - // Due to usage of MSHADOW_TYPE_SWITCH macro we are generating - // kernels where sizeof(LType) may be less than sizeof(DType), - // resulting in entries_per_load being 0. - // This is not a valid combination and is being checked against - // in the launcher code. This switch here is just to silence - // the division by zero warning generated for such invalid cases. - const int row_length = entries_per_load > 0 ? M / entries_per_load : 0; - - const LType* out_aligned = reinterpret_cast(out); - const LType* ograd_aligned = reinterpret_cast(ograd); - size_t base = my_row * row_length; - - for (index_t i = my_id; i < row_length; i += threads_per_row) { - persistent_storage[my_local_row * row_length * 2 + i] = out_aligned[base + i]; - persistent_storage[my_local_row * row_length * 2 + row_length + i] = ograd_aligned[base + i]; - } - DType * row = reinterpret_cast(persistent_storage + my_local_row * row_length * 2); - __syncthreads(); - - AType my_sum_value; - red::sum::SetInitValue(my_sum_value); - - for (index_t i = my_id; i < len; i += threads_per_row) { - my_sum_value += OP1::Map(row[i + M], row[i]); - } - scratch[threadIdx.x] = my_sum_value; - __syncthreads(); - for (int size = threads_per_row / 2; size >= warp_size; size /= 2) { - if (my_id < size) { - scratch[threadIdx.x] = scratch[threadIdx.x] + scratch[threadIdx.x + size]; - } - __syncthreads(); - } - if (my_id < warp_size) { - AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x], - [](AType x, AType y) { return x + y; }); - scratch[threadIdx.x] = my_value; - } - __syncthreads(); - AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row]; - __syncthreads(); - - for (index_t i = my_id; i < M; i += threads_per_row) { - const DType val = - negate ? - -OP2::Map(row[i + M], row[i], ssum) : - OP2::Map(row[i + M], row[i], ssum); - row[i] = (i < len) ? DType(val / static_cast(temperature)) : - DType(0.0f); - if (Req == kAddTo) { - row[i] += igrad[my_row * M + i]; - } - } - __syncthreads(); - - LType* igrad_aligned = reinterpret_cast(igrad); - - for (index_t i = my_id; i < row_length; i += threads_per_row) { - igrad_aligned[base + i] = persistent_storage[my_local_row * row_length * 2 + i]; - } -} - -template -__global__ void softmax_grad_kernel(OType *out, OType *ograd, DType *igrad, - const IType *length, index_t M, int axis, - Shape sshape, Shape stride, - const double temperature) { - const unsigned x_size = 1 << x_bits; - __shared__ AType smem[x_size]; - index_t sa = stride[axis]; - index_t base = unravel_dot(blockIdx.x, sshape, stride); - index_t x = threadIdx.x; - index_t len = length != nullptr ? static_cast(length[blockIdx.x]) : M; - - red::sum::SetInitValue(smem[x]); - for (index_t i = x; i < len; i += x_size) { - smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]); - } - __syncthreads(); - cuda::Reduce1D(smem); - __syncthreads(); - AType ssum = smem[0]; - __syncthreads(); - - DType final_result; - for (index_t i = x; i < M; i += x_size) { - final_result = - negate ? - -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) : - OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum); - final_result = (i < len) ? 
final_result : DType(0.0f); - KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result / static_cast(temperature)); - } -} - template __global__ void masked_softmax_stride1_grad_kernel(const OType *out, const OType *ograd, @@ -1041,48 +752,6 @@ __global__ void masked_softmax_grad_kernel(OType *out, OType *ograd, DType *igra } } -template -inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, - DType *igrad, IType *length, Shape shape, int axis, - const double temperature) { - const int x_bits = 7; - const int x_size = 1 << x_bits; - index_t M = shape[axis]; - if (M == 0 || shape.Size() == 0) return; - index_t N = shape.Size()/M; - Shape stride = calc_stride(shape); - Shape sshape = shape; - sshape[axis] = 1; - - const size_t DSize = sizeof(DType); - // Using 20 kB of shared memory for persistent storage in the optimized case - // Need to store both out and ograd, so M can be only half compared to - // forward pass. - const size_t max_opt_M = 20 * 1024 / DSize / 2; - if (stride[axis] == 1 && - static_cast(M) <= max_opt_M && - std::is_same::value) { - int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType)); - MXNET_LOAD_TYPE_SWITCH(ltype, LType, { - int rows_per_block = mxnet::common::cuda::get_rows_per_block(M * - sizeof(DType) / sizeof(LType), - softmax_threads_per_block); - int nblocks = (N + rows_per_block - 1) / rows_per_block; - CHECK_LE(sizeof(DType), sizeof(LType)); - softmax_stride1_grad_kernel - <<::GetStream(s)>>>( - out, ograd, igrad, length, M, temperature, rows_per_block, N); - }); - MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_grad_kernel); - } else { - softmax_grad_kernel - <<::GetStream(s)>>>( - out, ograd, igrad, length, M, axis, sshape, stride, temperature); - MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_grad_kernel); - } -} - template inline void MaskedSoftmaxGrad(Stream *s, OType *out, OType *ograd, @@ -1554,6 +1223,32 @@ void MaskedSoftmaxCompute(const nnvm::NodeAttrs& attrs, }); } +#if MXNET_USE_CUDA + +struct SoftmaxRTCCompute { + std::string OP; + bool negate = false; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +struct SoftmaxRTCGradCompute { + std::string OP1; + std::string OP2; + bool negate = false; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +#endif template void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/nn/softmax.cu b/src/operator/nn/softmax.cu index c75f543257c7..c8a05b840ab9 100644 --- a/src/operator/nn/softmax.cu +++ b/src/operator/nn/softmax.cu @@ -22,18 +22,791 @@ * \file softmax.cu * \brief GPU Implementation of softmax */ +#include #include "./softmax-inl.h" -#include "../tensor/elemwise_unary_op.h" +#include "../../common/cuda/utils.h" +#include "../../common/utils.h" +#include "../../common/cuda/rtc.h" +#include "../../common/cuda/rtc/vectorization-inl.h" namespace mxnet { namespace op { +namespace { + +struct softmax_params { + const void* inputs[3]; + void* outputs[1]; + index_t stride; + index_t num_elements; + double temperature; + int rows_per_block; + index_t total_rows; +}; + +const char softmax_common_functions[] = R"code( +struct softmax_params { + const void* inputs[3]; + void* outputs[1]; + index_t stride; + index_t num_elements; + double temperature; + int rows_per_block; + index_t total_rows; +}; + +template +__device__ inline type_util::mixed_type 
+softmax_fwd(const DType a, const DType2 b) { + return op::exp(a) / b; +} + +template +__device__ inline type_util::mixed_type +log_softmax_fwd(const DType a, const DType2 b) { + return a - op::log(b); +} + +template +__device__ inline type_util::mixed_type +softmax_bwd(DType ograd, DType2 out, DType3 sum) { + return out * (ograd - sum); +} + +template +__device__ inline type_util::mixed_type +log_softmax_bwd(DType ograd, DType2 out, DType3 sum) { + return ograd - op::exp(out) * sum; +} + +)code"; + +const char simple_softmax_kernel_fwd[] = R"code( +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void simple_softmax_kernel(const softmax_params param, + const index_t lead_dim) { + using LengthType = AccType; + const InputType0* input = reinterpret_cast(param.inputs[0]); + const InputType1* length = reinterpret_cast(param.inputs[1]); + const index_t len = length == nullptr + ? lead_dim + : static_cast(LengthType::from(length[blockIdx.x])); + const int my_row = threadIdx.x % param.rows_per_block; + const int my_id = threadIdx.x / param.rows_per_block; + const int threads_per_row = blockDim.x / param.rows_per_block; + const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % param.stride; + const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / param.stride; + const index_t base = base_x + param.stride * lead_dim * base_n; + if (base >= param.num_elements * param.total_rows) return; + using IType = AccType; + using OType = AccType; + using AType = type_util::mixed_type; + __shared__ AType smem[kRTCMaxThreadsPerBlock]; + AType max; + red::maximum::SetInitValue(max); + for (index_t i = my_id; i < len; i += threads_per_row) { + auto val = IType::from(input[base + i * param.stride]); + max = op::max(max, negate ? -val : val); + } + smem[threadIdx.x] = max; + __syncthreads(); + for (int size = blockDim.x / 2; size >= warp_size; size /= 2) { + if (threadIdx.x < size) { + smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]); + } + __syncthreads(); + } + if (threadIdx.x < warp_size) { + AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x], + [](AType x, AType y) + { return op::max(x, y); }, + param.rows_per_block); + smem[threadIdx.x] = my_value; + } + __syncthreads(); + AType smax = smem[my_row]; + __syncthreads(); + + AType sum; + red::sum::SetInitValue(sum); + for (index_t i = my_id; i < len; i += threads_per_row) { + auto val = IType::from(input[base + i * param.stride]); + val = negate ? -val :val; + sum += op::exp((val - smax) / static_cast(param.temperature)); + } + smem[threadIdx.x] = sum; + __syncthreads(); + for (int size = blockDim.x / 2; size >= warp_size; size /= 2) { + if (threadIdx.x < size) { + smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]); + } + __syncthreads(); + } + if (threadIdx.x < warp_size) { + AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x], + [](AType x, AType y) + { return op::add(x, y); }, + param.rows_per_block); + smem[threadIdx.x] = my_value; + } + __syncthreads(); + sum = smem[my_row]; + __syncthreads(); + + OutputType0* output = reinterpret_cast(param.outputs[0]); + for (index_t i = my_id; i < lead_dim; i += threads_per_row) { + auto val = IType::from(input[base + i * param.stride]); + val = negate ? -val : val; + val = (i < len) ? 
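For reference, the forward helpers above are applied to a = (val - smax) / temperature and b = the reduced sum, so softmax_fwd yields exp((x_i - max)/T) / sum_j exp((x_j - max)/T) and log_softmax_fwd the corresponding logarithm; the max subtraction cancels mathematically and exists only for numerical stability. A host-side sketch of the same computation (illustrative plain C++, not MXNet code; it omits the length masking and the negate flag that softmin uses):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative CPU reference of the temperature softmax computed by the RTC
// kernels:
//   out_i = exp((x_i - max)/T) / sum_j exp((x_j - max)/T)    (softmax_fwd)
//   out_i = (x_i - max)/T - log(sum_j exp((x_j - max)/T))    (log_softmax_fwd)
std::vector<float> softmax_reference(const std::vector<float>& x, double T,
                                     bool log_softmax = false) {
  const float smax = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (float v : x) sum += std::exp((v - smax) / T);
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double a = (x[i] - smax) / T;
    out[i] = log_softmax ? static_cast<float>(a - std::log(sum))
                         : static_cast<float>(std::exp(a) / sum);
  }
  return out;
}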
OP((val - smax)/static_cast(param.temperature), sum) : 0; + if (req == OpReqType::kAddTo) { + if (i < len) { + output[base + i * param.stride] = OType::to(val + + OType::from(output[base + i * param.stride])); + } + } else { + output[base + i * param.stride] = OType::to(val); + } + } +} +)code"; + +const char softmax_stride1_kernel_fwd[] = R"code( +__launch_bounds__(vector::vectorized_kernel_thread_num) +__global__ void softmax_stride1_compute_kernel(const softmax_params param, + const index_t total_length, + const index_t other_dim, + const index_t N, + const index_t num_aligned_elements) { + using namespace vector; + using IType = AccType; + using OType = AccType; + using LengthType = AccType; + const InputType1* length = reinterpret_cast(param.inputs[1]); + using AType = type_util::mixed_type; + __shared__ AType scratch[vectorized_kernel_thread_num]; + __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)]; + const int threads_per_row = vectorized_kernel_thread_num / param.rows_per_block; + const int my_local_row = threadIdx.x / threads_per_row; + const int base_row = blockIdx.x * param.rows_per_block; + const int my_row = base_row + my_local_row; + const index_t len = (length == nullptr || + my_row >= param.total_rows) ? param.num_elements + : LengthType::from(length[my_row]); + const int my_id = threadIdx.x % threads_per_row; + + AType* row; + if (only_full_blocks || blockIdx.x < gridDim.x - 1) { + // full rows_per_block rows to compute + VectorizedLoader loader( + reinterpret_cast(param.inputs[0]) + base_row * param.num_elements, + total_length); + for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) { + loader.load(i, total_length); +#pragma unroll + for (int j = 0; j < nvec; ++j) { + persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]); + } + } + row = persistent_storage + my_local_row * param.num_elements + loader.alignment(); + } else { + // less than rows_per_block rows to compute + const index_t real_length = min(total_length, + (param.total_rows - base_row) * param.num_elements); + VectorizedLoader loader( + reinterpret_cast(param.inputs[0]) + base_row * param.num_elements, + real_length); + for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) { + loader.load(i, real_length); +#pragma unroll + for (int j = 0; j < nvec; ++j) { + persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]); + } + } + row = persistent_storage + my_local_row * param.num_elements + loader.alignment(); + } + __syncthreads(); + + AType my_max_value; + red::maximum::SetInitValue(my_max_value); + + for (index_t i = my_id; i < len; i += threads_per_row) { + my_max_value = ::max(my_max_value, negate ? 
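The stride1 path stages each row in shared memory through VectorizedLoader, which reads the row via a wider aligned type and unpacks nvec scalar elements per load (the helper also deals with the misaligned head and tail, which is what loader.alignment() accounts for). A stripped-down standalone illustration of that access pattern, not the actual VectorizedLoader API:

#include <cuda_runtime.h>

// Illustrative only: copy a float row using 8-byte (float2) transactions, the
// idea behind the VectorizedLoader/VectorizedStorer pair. Assumes in/out are
// 8-byte aligned and len is even; the real helpers also handle unaligned and
// leftover elements.
__global__ void copy_vectorized(const float* __restrict__ in,
                                float* __restrict__ out,
                                int len) {
  const float2* in2 = reinterpret_cast<const float2*>(in);
  float2* out2 = reinterpret_cast<float2*>(out);
  const int n2 = len / 2;  // two scalars ("nvec" = 2) per transaction
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n2;
       i += gridDim.x * blockDim.x) {
    out2[i] = in2[i];  // one wide load and one wide store move two elements
  }
}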
-row[i] : row[i]); + } + AType smax; + if (!reduction_inside_warp) { + scratch[threadIdx.x] = my_max_value; + __syncthreads(); + for (int size = threads_per_row / 2; size >= warp_size; size /= 2) { + if (my_id < size) { + scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]); + } + __syncthreads(); + } + if (my_id < warp_size) { + AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x], + [](AType x, AType y) { return op::max(x, y); }, + min(threads_per_row, warp_size)); + scratch[threadIdx.x] = my_value; + } + __syncthreads(); + smax = scratch[threadIdx.x - my_id]; + __syncthreads(); + } else { + smax = util::grouped_warp_allreduce(my_max_value, + [](AType x, AType y) { return op::max(x, y); }, + threads_per_row); + } + + AType my_sum; + red::sum::SetInitValue(my_sum); + + for (index_t i = my_id; i < len; i += threads_per_row) { + const AType val = negate ? -row[i] : row[i]; + my_sum += op::exp((val - smax) / static_cast(param.temperature)); + } + AType ssum; + if (!reduction_inside_warp) { + scratch[threadIdx.x] = my_sum; + __syncthreads(); + for (int size = threads_per_row / 2; size >= warp_size; size /= 2) { + if (my_id < size) { + scratch[threadIdx.x] += scratch[threadIdx.x + size]; + } + __syncthreads(); + } + if (my_id < warp_size) { + AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x], + [](AType x, AType y) { return x + y;}, + min(threads_per_row, warp_size)); + scratch[threadIdx.x] = my_value; + } + __syncthreads(); + + ssum = scratch[threadIdx.x - my_id]; + __syncthreads(); + } else { + ssum = util::grouped_warp_allreduce(my_sum, + [](AType x, AType y) { return x + y;}, + threads_per_row); + } + + for (index_t i = my_id; i < param.num_elements; i += threads_per_row) { + const AType val = negate ? -row[i] : row[i]; + row[i] = (i < len) ? 
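Both reduction variants above end in warp-level helpers (util::strided_grouped_warp_reduce in the simple kernel, util::grouped_warp_allreduce here); their definitions are not part of this diff. A hypothetical stand-in showing the butterfly pattern such a grouped allreduce is assumed to follow:

// Hypothetical sketch, not the RTC helper's real definition: every thread in a
// power-of-two group of group_size lanes (groups aligned within one warp) ends
// up holding the reduction of its whole group. T must be a type supported by
// __shfl_xor_sync (e.g. int, float, double).
template <typename T, typename Op>
__device__ T grouped_warp_allreduce_sketch(T value, Op op, int group_size) {
  for (int offset = group_size / 2; offset > 0; offset /= 2) {
    value = op(value, __shfl_xor_sync(0xffffffffu, value, offset));
  }
  return value;
}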
OP((val - smax)/static_cast(param.temperature), ssum) : + 0; + } + __syncthreads(); + + if (only_full_blocks || blockIdx.x < gridDim.x - 1) { + VectorizedStorer storer( + reinterpret_cast(param.outputs[0]) + base_row * param.num_elements, + total_length); + + for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) { + if (req == OpReqType::kAddTo) { + storer.load(i, total_length); +#pragma unroll + for (int j = 0; j < nvec; ++j) { + storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j], + OType::from(storer.separate()[j]))); + } + } else { +#pragma unroll + for (int j = 0; j < nvec; ++j) { + storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]); + } + } + storer.store(i, total_length); + } + } else { + const index_t real_length = min(total_length, + (param.total_rows - base_row) * param.num_elements); + VectorizedStorer storer( + reinterpret_cast(param.outputs[0]) + base_row * param.num_elements, + real_length); + + for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) { + if (req == OpReqType::kAddTo) { + storer.load(i, real_length); +#pragma unroll + for (int j = 0; j < nvec; ++j) { + storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j], + OType::from(storer.separate()[j]))); + } + } else { +#pragma unroll + for (int j = 0; j < nvec; ++j) { + storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]); + } + } + storer.store(i, real_length); + } + } +} +)code"; + +int get_rows_per_block(const index_t row_size, const int nvec, + const index_t max_storage, const int num_threads_per_block, + const index_t total_rows, const int dev_id) { + CHECK(common::IsPower2(num_threads_per_block)) + << "Number of threads in a block must be power of 2 to use get_rows_per_block function"; + // How many read instructions should 1 thread at least do + const int read_instructions = 16; + const size_t row_size_in_vec = (row_size + nvec - 1) / nvec; + int desired_num_threads_per_row = (row_size_in_vec + read_instructions - 1) / read_instructions; + desired_num_threads_per_row = common::RoundToPower2(desired_num_threads_per_row); + desired_num_threads_per_row = std::min(desired_num_threads_per_row, num_threads_per_block); + const int desired_rows_per_block = num_threads_per_block / desired_num_threads_per_row; + int actual_rows_per_block = desired_rows_per_block; + int num_sms = MultiprocessorCount(dev_id); + while (actual_rows_per_block > 1 && + ((max_storage != -1 && max_storage < row_size * actual_rows_per_block) || + (total_rows + actual_rows_per_block - 1) / actual_rows_per_block < num_sms)) { + actual_rows_per_block /= 2; + } + return actual_rows_per_block; +} + +} // namespace + +void SoftmaxRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using common::mshadow_type_info; + using namespace common::cuda::rtc; + using common::div_round; + if (req[0] == kNullOp || inputs[0].Size() == 0U) return; + const SoftmaxParam& param = nnvm::get(attrs.parsed); + int axis = CheckAxis(param.axis, inputs[0].ndim()); + const double temperature = param.temperature.has_value() ? 
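A worked example of the get_rows_per_block heuristic above, with illustrative numbers: assuming vectorized_kernel_thread_num is 512 and float accumulation (so the 20 kB budget holds 5120 values), a row of M = 128 elements with nvec = 2 gives row_size_in_vec = 64, hence ceil(64 / 16) = 4 threads per row and a first guess of 512 / 4 = 128 rows per block; the shared-memory check then halves that down to 32 (32 * 128 = 4096 <= 5120), provided N / 32 blocks still keep every SM busy. For rows too long for the optimized path the loop bottoms out at 1, and the M * rows_per_block <= max_opt_M test below falls back to simple_softmax_kernel.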
+ param.temperature.value() : 1.0; + mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true); + + void* length_ptr = nullptr; + std::string length_typename = "int"; + if (param.use_length.value()) { + CHECK(inputs.size() > 1) + << "Mask needs to be provided when using softmax with use_length=True."; + length_ptr = inputs[1].dptr_; + length_typename = mshadow_type_info(inputs[1].type_flag_).name; + } + CHECK_EQ(outputs.size(), 1); + index_t M = shape[axis]; + if (M == 0 || shape.Size() == 0) return; + index_t stride = 1; + if (axis == shape.ndim() - 2) { + stride = shape[shape.ndim() - 1]; + } + const index_t N = shape.Size() / M; + softmax_params params = {{inputs[0].dptr_, length_ptr, nullptr}, + {outputs[0].dptr_}, + stride, M, + temperature, 1, N}; + std::string code = "#define OP " + OP + "\n" + "const OpReqType req = " + util::to_string(req[0]) + ";\n" + "const bool negate = " + std::to_string(negate) + ";\n" + "using InputType1 = " + length_typename + ";\n"; + Stream* s = ctx.get_stream(); + + constexpr int nvec = 2; + // Using 20 kB of shared memory for persistent storage in the optimized case + const size_t acc_type_size = std::max(mshadow_type_info(inputs[0].type_flag_).acc_size, + mshadow_type_info(outputs[0].type_flag_).acc_size); + const size_t max_opt_M = 20 * 1024 / acc_type_size; + int rows_per_block = get_rows_per_block(M, nvec, max_opt_M, + vectorized_kernel_thread_num, + N, ctx.run_ctx.ctx.dev_id); + constexpr int warp_size = common::cuda::warp_size; + if (stride == 1 && + static_cast(M * rows_per_block) <= max_opt_M) { + code += "const bool only_full_blocks = " + std::to_string(N % rows_per_block == 0) + ";\n" + "const bool reduction_inside_warp = " + + std::to_string(vectorized_kernel_thread_num / rows_per_block <= warp_size) + ";\n"; + params.rows_per_block = rows_per_block; + int nblocks = (N + rows_per_block - 1) / rows_per_block; + VectorizedKernelRTCLauncher(code + softmax_common_functions, "softmax_stride1_compute_kernel", + softmax_stride1_kernel_fwd, nvec, + M * rows_per_block, N / rows_per_block, s, params, + inputs, outputs, + ctx.run_ctx.ctx.dev_id, 0, nblocks); + MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_compute_kernel); + } else { + code += "using InputType0 = " + mshadow_type_info(inputs[0].type_flag_).name + ";\n" + "using OutputType0 = " + mshadow_type_info(outputs[0].type_flag_).name + ";\n"; + std::vector args; + args.emplace_back(¶ms); + args.emplace_back(&M); + int num_threads = std::min(static_cast(128), + common::RoundToPower2(div_round(M, warp_size) * warp_size)); + if (stride != 1) { + const int num_sms = MultiprocessorCount(ctx.run_ctx.ctx.dev_id); + const index_t rows_per_sm = div_round(N, (512 / num_threads) * num_sms); + params.rows_per_block = std::min(static_cast(warp_size), + common::RoundToPower2(rows_per_sm)); + } + const auto& kernel_func = get_function(code + softmax_common_functions, + "simple_softmax_kernel", + simple_softmax_kernel_fwd, + ctx.run_ctx.ctx.dev_id); + launch(kernel_func, div_round(N, params.rows_per_block), num_threads, 0, s, &args); + MSHADOW_CUDA_POST_KERNEL_CHECK(simple_softmax_kernel); + } +} + +const char simple_softmax_kernel_bwd[] = R"code( +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void simple_softmax_grad_kernel(const softmax_params param, + const index_t lead_dim) { + using LengthType = AccType; + const InputType0* out = reinterpret_cast(param.inputs[0]); + const InputType1* ograd = reinterpret_cast(param.inputs[1]); + const InputType2* length = 
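To make the RTC specialization concrete: the host path above does not instantiate C++ templates, it pastes the operator name and the compile-time switches into the source string, so each distinct combination is compiled once (and cached) as its own kernel. For a float softmax with kWriteTo, no length input and a row count divisible by rows_per_block, the generated prelude would look roughly like the following (illustrative; the exact spelling produced by util::to_string and the launcher-added type aliases may differ):

// Illustrative prelude prepended to softmax_common_functions and the kernel
// string; OP, req and negate become ordinary compile-time constants in the
// generated translation unit, so untaken branches are compiled away.
#define OP softmax_fwd
const OpReqType req = OpReqType::kWriteTo;
const bool negate = false;
using InputType1 = int;   // length tensor type; unused when length_ptr is null
const bool only_full_blocks = true;
const bool reduction_inside_warp = false;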
reinterpret_cast(param.inputs[2]); + const index_t len = length == nullptr + ? lead_dim + : static_cast(LengthType::from(length[blockIdx.x])); + const int my_row = threadIdx.x % param.rows_per_block; + const int my_id = threadIdx.x / param.rows_per_block; + const int threads_per_row = blockDim.x / param.rows_per_block; + const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % param.stride; + const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / param.stride; + const index_t base = base_x + param.stride * lead_dim * base_n; + if (base >= param.num_elements * param.total_rows) return; + using IType0 = AccType; + using IType1 = AccType; + using OType = AccType; + using AType = type_util::mixed_type; + __shared__ AType smem[kRTCMaxThreadsPerBlock]; + AType sum; + red::sum::SetInitValue(sum); + for (index_t i = my_id; i < len; i += threads_per_row) { + auto out_val = IType0::from(out[base + i * param.stride]); + auto ograd_val = IType1::from(ograd[base + i * param.stride]); + sum += OP1(ograd_val, out_val); + } + smem[threadIdx.x] = sum; + __syncthreads(); + for (int size = blockDim.x / 2; size >= warp_size; size /= 2) { + if (threadIdx.x < size) { + smem[threadIdx.x] = smem[threadIdx.x] + smem[threadIdx.x + size]; + } + __syncthreads(); + } + if (threadIdx.x < warp_size) { + AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x], + [](AType x, AType y) { return x + y; }, + param.rows_per_block); + smem[threadIdx.x] = my_value; + } + __syncthreads(); + sum = smem[my_row]; + __syncthreads(); + + OutputType0* igrad = reinterpret_cast(param.outputs[0]); + for (index_t i = my_id; i < lead_dim; i += threads_per_row) { + auto out_val = IType0::from(out[base + i * param.stride]); + auto ograd_val = IType1::from(ograd[base + i * param.stride]); + auto val = (i < len) ? OP2(ograd_val, out_val, sum) / static_cast(param.temperature) : 0; + val = negate ? -val : val; + if (req == OpReqType::kAddTo) { + if (i < len) { + igrad[base + i * param.stride] = OType::to(val + + OType::from(igrad[base + i * param.stride])); + } + } else { + igrad[base + i * param.stride] = OType::to(val); + } + } +} +)code"; + +const char softmax_stride1_kernel_bwd[] = R"code( +__launch_bounds__(vector::vectorized_kernel_thread_num) +__global__ void softmax_stride1_compute_grad_kernel(const softmax_params param, + const index_t total_length, + const index_t other_dim, + const index_t N, + const index_t num_aligned_elements) { + using namespace vector; + using IType0 = AccType; + using IType1 = AccType; + using OType = AccType; + using LengthType = AccType; + const InputType2* length = reinterpret_cast(param.inputs[2]); + using AType = type_util::mixed_type; + __shared__ AType scratch[vectorized_kernel_thread_num]; + __shared__ AType output_persistent_storage[10 * 1024 / sizeof(AType)]; + __shared__ AType ograd_persistent_storage[10 * 1024 / sizeof(AType)]; + const int warp_size = 32; + const int threads_per_row = vectorized_kernel_thread_num / param.rows_per_block; + const int my_local_row = threadIdx.x / threads_per_row; + const int base_row = blockIdx.x * param.rows_per_block; + const int my_row = base_row + my_local_row; + const index_t len = (length == nullptr || + my_row >= param.total_rows) ? 
param.num_elements + : LengthType::from(length[my_row]); + const int my_id = threadIdx.x % threads_per_row; + + AType* output_row; + AType* ograd_row; + if (only_full_blocks || blockIdx.x < gridDim.x - 1) { + // full rows_per_block rows to compute + VectorizedLoader output_loader( + reinterpret_cast(param.inputs[0]) + base_row * param.num_elements, + total_length); + VectorizedLoader ograd_loader( + reinterpret_cast(param.inputs[1]) + base_row * param.num_elements, + total_length); + for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) { + output_loader.load(i, total_length); + ograd_loader.load(i, total_length); +#pragma unroll + for (int j = 0; j < nvec; ++j) { + output_persistent_storage[i*nvec + j] = IType0::from(output_loader.separate()[j]); + ograd_persistent_storage[i*nvec + j] = IType1::from(ograd_loader.separate()[j]); + } + } + output_row = output_persistent_storage + + my_local_row * param.num_elements + + output_loader.alignment(); + ograd_row = ograd_persistent_storage + + my_local_row * param.num_elements + + ograd_loader.alignment(); + } else { + // less than rows_per_block rows to compute + const index_t real_length = min(total_length, + (param.total_rows - base_row) * param.num_elements); + VectorizedLoader output_loader( + reinterpret_cast(param.inputs[0]) + base_row * param.num_elements, + real_length); + VectorizedLoader ograd_loader( + reinterpret_cast(param.inputs[1]) + base_row * param.num_elements, + real_length); + for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) { + output_loader.load(i, real_length); + ograd_loader.load(i, real_length); +#pragma unroll + for (int j = 0; j < nvec; ++j) { + output_persistent_storage[i*nvec + j] = IType0::from(output_loader.separate()[j]); + ograd_persistent_storage[i*nvec + j] = IType1::from(ograd_loader.separate()[j]); + } + } + output_row = output_persistent_storage + + my_local_row * param.num_elements + + output_loader.alignment(); + ograd_row = ograd_persistent_storage + + my_local_row * param.num_elements + + ograd_loader.alignment(); + } + __syncthreads(); + + AType my_sum; + red::sum::SetInitValue(my_sum); + + for (index_t i = my_id; i < len; i += threads_per_row) { + const AType val = OP1(ograd_row[i], output_row[i]); + my_sum += val; + } + AType ssum; + if (!reduction_inside_warp) { + scratch[threadIdx.x] = my_sum; + __syncthreads(); + for (int size = threads_per_row / 2; size >= warp_size; size /= 2) { + if (my_id < size) { + scratch[threadIdx.x] += scratch[threadIdx.x + size]; + } + __syncthreads(); + } + if (my_id < warp_size) { + AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x], + [](AType x, AType y) { return x + y;}, + min(threads_per_row, warp_size)); + scratch[threadIdx.x] = my_value; + } + __syncthreads(); + + ssum = scratch[threadIdx.x - my_id]; + __syncthreads(); + } else { + ssum = util::grouped_warp_allreduce(my_sum, + [](AType x, AType y) { return x + y;}, + threads_per_row); + } + + for (index_t i = my_id; i < param.num_elements; i += threads_per_row) { + AType val = (i < len) + ? OP2(ograd_row[i], output_row[i], ssum) / static_cast(param.temperature) + : 0; + output_row[i] = negate ? 
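The backward kernels reduce s = sum_i OP1(ograd_i, out_i) over the (possibly length-masked) row and then write OP2(ograd_i, out_i, s) / T; with OP1 = op::mul and OP2 = softmax_bwd, as registered below, that is the usual igrad_i = out_i * (ograd_i - sum_j ograd_j * out_j) / T, with the negate flag (softmin) only flipping the sign afterwards. A host-side illustration (plain C++, not MXNet code):

#include <cstddef>
#include <vector>

// Illustrative CPU reference of the softmax gradient the RTC kernels compute
// for OP1 = op::mul, OP2 = softmax_bwd:
//   igrad_i = out_i * (ograd_i - sum_j ograd_j * out_j) / T
std::vector<float> softmax_grad_reference(const std::vector<float>& out,
                                          const std::vector<float>& ograd,
                                          double T) {
  double sum = 0.0;
  for (std::size_t i = 0; i < out.size(); ++i)
    sum += static_cast<double>(ograd[i]) * out[i];
  std::vector<float> igrad(out.size());
  for (std::size_t i = 0; i < out.size(); ++i)
    igrad[i] = static_cast<float>(out[i] * (ograd[i] - sum) / T);
  return igrad;
}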
-val : val; + } + __syncthreads(); + + if (only_full_blocks || blockIdx.x < gridDim.x - 1) { + VectorizedStorer storer( + reinterpret_cast(param.outputs[0]) + base_row * param.num_elements, + total_length); + + for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) { + if (req == OpReqType::kAddTo) { + storer.load(i, total_length); +#pragma unroll + for (int j = 0; j < nvec; ++j) { + storer.separate()[j] = OType::to(op::add(output_persistent_storage[i*nvec + j], + OType::from(storer.separate()[j]))); + } + } else { +#pragma unroll + for (int j = 0; j < nvec; ++j) { + storer.separate()[j] = OType::to(output_persistent_storage[i*nvec + j]); + } + } + storer.store(i, total_length); + } + } else { + const index_t real_length = min(total_length, + (param.total_rows - base_row) * param.num_elements); + VectorizedStorer storer( + reinterpret_cast(param.outputs[0]) + base_row * param.num_elements, + real_length); + + for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) { + if (req == OpReqType::kAddTo) { + storer.load(i, real_length); +#pragma unroll + for (int j = 0; j < nvec; ++j) { + storer.separate()[j] = OType::to(op::add(output_persistent_storage[i*nvec + j], + OType::from(storer.separate()[j]))); + } + } else { +#pragma unroll + for (int j = 0; j < nvec; ++j) { + storer.separate()[j] = OType::to(output_persistent_storage[i*nvec + j]); + } + } + storer.store(i, real_length); + } + } +} +)code"; + +void SoftmaxRTCGradCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using common::mshadow_type_info; + using namespace common::cuda::rtc; + using common::div_round; + Stream* s = ctx.get_stream(); + if (softmax_use_length(attrs)) { + if (req[1] != kNullOp) { + cudaMemsetAsync(outputs[1].dptr_, 0, + outputs[1].Size() * mshadow_type_info(outputs[1].type_flag_).size, + Stream::GetStream(s)); + } + } + if (req[0] == kNullOp || inputs[0].Size() == 0U) return; + const SoftmaxParam& param = nnvm::get(attrs.parsed); + int axis = CheckAxis(param.axis, inputs[0].ndim()); + const double temperature = param.temperature.has_value() ? + param.temperature.value() : 1.0; + mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true); + + int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1; + out_idx = softmax_use_length(attrs) ? 
3 : out_idx; + + void* length_ptr = nullptr; + std::string length_typename = "int"; + if (softmax_use_length(attrs)) { + length_ptr = inputs[2].dptr_; + length_typename = mshadow_type_info(inputs[2].type_flag_).name; + } + index_t M = shape[axis]; + if (M == 0 || shape.Size() == 0) return; + index_t stride = 1; + if (axis == shape.ndim() - 2) { + stride = shape[shape.ndim() - 1]; + } + const index_t N = shape.Size() / M; + softmax_params params = {{inputs[out_idx].dptr_, inputs[0].dptr_, length_ptr}, + {outputs[0].dptr_}, + stride, M, + temperature, 1, N}; + std::string code = "#define OP1 " + OP1 + "\n" + "#define OP2 " + OP2 + "\n" + "const OpReqType req = " + util::to_string(req[0]) + ";\n" + "const bool negate = " + std::to_string(negate) + ";\n" + "using InputType2 = " + length_typename + ";\n"; + + constexpr int nvec = 2; + // Using 20 kB of shared memory for persistent storage in the optimized case + const size_t acc_type_size = std::max(mshadow_type_info(inputs[0].type_flag_).acc_size, + mshadow_type_info(outputs[0].type_flag_).acc_size); + const size_t max_opt_M = 10 * 1024 / acc_type_size; + int rows_per_block = get_rows_per_block(M, nvec, max_opt_M, + vectorized_kernel_thread_num, + N, ctx.run_ctx.ctx.dev_id); + params.rows_per_block = rows_per_block; + bool debug_softmax = dmlc::GetEnv("DEBUG_SOFTMAX_GRAD", false); + if (!debug_softmax && stride == 1 && + static_cast(M * rows_per_block) <= max_opt_M) { + const int warp_size = 32; + code += "const bool only_full_blocks = " + std::to_string(N % rows_per_block == 0) + ";\n" + "const bool reduction_inside_warp = " + + std::to_string(vectorized_kernel_thread_num / rows_per_block <= warp_size) + ";\n"; + int nblocks = div_round(N, rows_per_block); + std::vector new_inputs = {inputs[out_idx], inputs[0]}; + if (softmax_use_length(attrs)) { + new_inputs.emplace_back(inputs[2]); + } + std::vector new_outputs = {outputs[0]}; + VectorizedKernelRTCLauncher(code + softmax_common_functions, + "softmax_stride1_compute_grad_kernel", + softmax_stride1_kernel_bwd, nvec, + M * rows_per_block, N / rows_per_block, s, params, + new_inputs, new_outputs, + ctx.run_ctx.ctx.dev_id, 0, nblocks); + MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_stride1_compute_grad_kernel); + } else { + code += "using InputType0 = " + mshadow_type_info(inputs[out_idx].type_flag_).name + ";\n" + "using InputType1 = " + mshadow_type_info(inputs[0].type_flag_).name + ";\n" + "using OutputType0 = " + mshadow_type_info(outputs[0].type_flag_).name + ";\n"; + std::vector args; + args.emplace_back(¶ms); + args.emplace_back(&M); + const int warp_size = 32; + int num_threads = std::min(static_cast(128), + common::RoundToPower2(div_round(M, warp_size) * warp_size)); + if (stride != 1) { + const int num_sms = MultiprocessorCount(ctx.run_ctx.ctx.dev_id); + const index_t rows_per_sm = div_round(N, (512 / num_threads) * num_sms); + params.rows_per_block = std::min(static_cast(warp_size), + common::RoundToPower2(rows_per_sm)); + } + const auto& kernel_func = get_function(code + softmax_common_functions, + "simple_softmax_grad_kernel", + simple_softmax_kernel_bwd, + ctx.run_ctx.ctx.dev_id); + launch(kernel_func, div_round(N, params.rows_per_block), num_threads, 0, s, &args); + MSHADOW_CUDA_POST_KERNEL_CHECK(simple_softmax_grad_kernel); + } +} + NNVM_REGISTER_OP(softmax) -.set_attr("FCompute", SoftmaxCompute); +.set_attr("FCompute", SoftmaxRTCCompute{"softmax_fwd"}); NNVM_REGISTER_OP(_backward_softmax) -.set_attr("FCompute", SoftmaxGradCompute); +.set_attr("FCompute", 
SoftmaxRTCGradCompute{"op::mul", "softmax_bwd"}); + NNVM_REGISTER_OP(masked_softmax) .set_attr("FCompute", MaskedSoftmaxCompute); diff --git a/src/operator/nn/softmin.cu b/src/operator/nn/softmin.cu index d00d0bdad231..b6f56ced8dd0 100644 --- a/src/operator/nn/softmin.cu +++ b/src/operator/nn/softmin.cu @@ -29,11 +29,10 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(softmin) -.set_attr("FCompute", SoftmaxCompute); +.set_attr("FCompute", SoftmaxRTCCompute{"softmax_fwd", true}); NNVM_REGISTER_OP(_backward_softmin) -.set_attr("FCompute", SoftmaxGradCompute); +.set_attr("FCompute", SoftmaxRTCGradCompute{"op::mul", "softmax_bwd", true}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op.cc b/src/operator/tensor/elemwise_binary_scalar_op.cc index f09bf21cceb4..d7fde47fe3d1 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op.cc @@ -50,6 +50,7 @@ __global__ void binary_scalar_kernel(const binary_scalar_kernel_params params, const index_t N, const index_t num_aligned_elements) { using namespace vector; + using type_util::mixed_type; VectorizedLoader loader( reinterpret_cast(params.inputs[0]), N); VectorizedStorer storer( @@ -72,9 +73,8 @@ __global__ void binary_scalar_kernel(const binary_scalar_kernel_params params, const auto input = IType::from(loader.separate()[i]); // enables returning different type const auto temp = OP(input, - static_cast::type> - (params.scalar)); + static_cast>(params.scalar)); if (req == OpReqType::kAddTo) { // temp2 may have a wider type than either temp @@ -171,6 +171,7 @@ __global__ void binary_scalar_kernel_bwd(const binary_scalar_kernel_params param const index_t N, const index_t num_aligned_elements) { using namespace vector; + using type_util::mixed_type; VectorizedLoader ograd_loader( reinterpret_cast(params.inputs[0]), N); VectorizedLoader input_loader( @@ -199,9 +200,8 @@ __global__ void binary_scalar_kernel_bwd(const binary_scalar_kernel_params param // enables returning different type const auto temp = op::mul(ograd, OP(input, - static_cast - ::type>(params.scalar))); + static_cast>(params.scalar))); if (req == OpReqType::kAddTo) { // temp2 may have a wider type than either temp
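The elemwise_binary_scalar change depends on mixed_type being usable as an alias template, which is what removes the typename ...::type spelling from the casts. A minimal sketch of that pattern (illustrative only; the real type_util::mixed_type also has to pick proper accumulation types for half precision):

#include <type_traits>

// Illustrative pattern only; std::common_type stands in for the RTC header's
// actual widening rules.
namespace type_util {

template <typename... Ts>
using mixed_type = typename std::common_type<Ts...>::type;

}  // namespace type_util

// With `using type_util::mixed_type;` in scope, a kernel can write
//   static_cast<mixed_type<DType, double>>(params.scalar)
// instead of
//   static_cast<typename type_util::mixed_type<DType, double>::type>(params.scalar).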