diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h
index cacb8be5b4b6..85135ae6e888 100644
--- a/src/common/cuda/rtc/backward_functions-inl.h
+++ b/src/common/cuda/rtc/backward_functions-inl.h
@@ -50,15 +50,6 @@ backward_log_sigmoid(const DTypeGrad grad, const DType val) {
   return grad * 1 / (1 + op::exp(val));
 }
 
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_mish(const DTypeGrad grad, const DType val) {
-  const mixed_type<DTypeGrad, DType> v = val;
-  const auto softrelu = op::log(1 + exp(v));
-  const auto tanh = op::tanh(softrelu);
-  return grad * (tanh + v * sigmoid(v) * (1 - tanh * tanh));
-}
-
 template <typename DType, typename DTypeGrad>
 __device__ inline mixed_type<DTypeGrad, DType>
 backward_softrelu(const DTypeGrad grad, const DType val) {
@@ -212,6 +203,14 @@ backward_arctanh(const DTypeGrad grad, const DType val) {
   return grad / (1 - val * val);
 }
 
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_mish(const DTypeGrad grad, const DType val) {
+  const auto softrelu = op::softrelu(val);
+  const auto tanh_sr = op::tanh(softrelu);
+  return grad * (tanh_sr + val * sigmoid(val) * (1 - tanh_sr * tanh_sr));
+}
+
 template <typename DType, typename DTypeGrad>
 __device__ inline mixed_type<DTypeGrad, DType>
 backward_sqrt(const DTypeGrad grad, const DType out) {
diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h
index b353e92ab8f1..7a886a0a9aec 100644
--- a/src/common/cuda/rtc/forward_functions-inl.h
+++ b/src/common/cuda/rtc/forward_functions-inl.h
@@ -694,15 +694,6 @@ __device__ inline DType log_sigmoid(const DType val) {
   }
 }
 
-template <typename DType>
-__device__ inline DType mish(const DType val) {
-  if (type_util::has_double_or_integral<DType>::value) {
-    return val * ::tanh(::log(1 + ::exp(val)));
-  } else {
-    return val * ::tanhf(logf(1 + expf(val)));
-  }
-}
-
 template <typename DType>
 __device__ inline DType softrelu(const DType val) {
   // Avoid overflow of exp for large inputs.
@@ -780,6 +771,11 @@ DEFINE_UNARY_MATH_FUNC(arcsinh, ::asinh, ::asinhf)
 DEFINE_UNARY_MATH_FUNC(arccosh, ::acosh, ::acoshf)
 DEFINE_UNARY_MATH_FUNC(arctanh, ::atanh, ::atanhf)
 
+template <typename DType>
+__device__ inline DType mish(const DType val) {
+  return val * op::tanh(op::softrelu(val));
+}
+
 // sqrt
 DEFINE_UNARY_MATH_FUNC(sqrt, ::sqrt, ::sqrtf)
 
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index e28dc86f971b..611ddbcad472 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -415,11 +415,31 @@ MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));
 
 MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));
 
-MXNET_UNARY_MATH_OP(mish, a * math::tanh(math::log(1.0f + math::exp(a))));
+struct mish : public mxnet_op::tunable {
+  template <typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // Guarded softrelu (see op::softrelu): avoid exp overflow for large a.
+    auto softrelu = math::log1p(math::exp(a));
+    if (a > DType(20.0f)) {
+      softrelu = a;
+    }
+    return DType(a * math::tanh(softrelu));
+  }
+};
 
-MXNET_UNARY_MATH_OP(mish_grad, math::tanh(math::log(1.0f + math::exp(a))) +
-                               a * (1.0f / (1.0f + math::exp(-a))) *
-                               (1.0f - math::sqr(math::tanh(math::log(1.0f + math::exp(a))))));
+struct mish_grad : public mxnet_op::tunable {
+  template <typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // Note: the input (a) is x, not y.
+    auto softrelu = math::log1p(math::exp(a));
+    if (a > DType(20.0f)) {
+      softrelu = a;
+    }
+    auto tanh_sr = math::tanh(softrelu);
+    auto sr_grad = 1.0f / (1.0f + math::exp(-a));
+    return DType(tanh_sr + a * sr_grad * (1.0f - tanh_sr * tanh_sr));
+  }
+};
 
 MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));
 
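
Why the guarded softrelu matters: the naive `log(1 + exp(a))` overflows `exp` for large inputs and rounds `1 + exp(a)` to exactly `1` for very negative ones, while the `log1p` form with the `a > 20` cutoff (the same guard `op::softrelu` already uses) avoids both. This is also why the forward `op::mish` moves below `op::softrelu` in forward_functions-inl.h, presumably so that `op::softrelu` is defined before `op::mish` refers to it, and why the explicit double/float branching disappears: `op::tanh` and `op::softrelu` already dispatch on type. A minimal host-side sketch of the numerics, not part of the patch; `naive_softplus` and `guarded_softplus` are hypothetical names for illustration:

    #include <cmath>
    #include <cstdio>

    // Hypothetical helpers, illustration only (not part of the patch).
    static float naive_softplus(float a) {
      return std::log(1.0f + std::exp(a));
    }
    static float guarded_softplus(float a) {
      // Same cutoff as op::softrelu: for a > 20, softplus(a) - a =
      // log1p(exp(-a)) < 3e-9, below float rounding at that magnitude.
      return a > 20.0f ? a : std::log1p(std::exp(a));
    }

    int main() {
      // std::exp(89.0f) overflows float, so the naive form returns inf.
      std::printf("naive(89)    = %g\n", naive_softplus(89.0f));    // inf
      std::printf("guarded(89)  = %g\n", guarded_softplus(89.0f));  // 89
      // 1.0f + std::exp(-20.0f) rounds to 1.0f, so the naive form
      // collapses to exactly 0; log1p keeps the small tail.
      std::printf("naive(-20)   = %g\n", naive_softplus(-20.0f));   // 0
      std::printf("guarded(-20) = %g\n", guarded_softplus(-20.0f)); // ~2.06e-09
      return 0;
    }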
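
The backward formula follows from the chain rule: with mish(x) = x * tanh(sp(x)) and sp'(x) = sigmoid(x), we get mish'(x) = tanh(sp(x)) + x * sigmoid(x) * (1 - tanh^2(sp(x))). The gradient is evaluated from the input x rather than the saved output y, hence the comment in `mish_grad`. A quick host-side finite-difference check of the closed form (illustrative sketch in double precision, not part of the patch):

    #include <cmath>
    #include <cstdio>

    // Host-side doubles mirroring the device-side formulas.
    static double softplus(double x) {
      return x > 20.0 ? x : std::log1p(std::exp(x));
    }
    static double mish(double x) { return x * std::tanh(softplus(x)); }

    // Closed form used by backward_mish / mish_grad above.
    static double mish_grad(double x) {
      const double t   = std::tanh(softplus(x));
      const double sig = 1.0 / (1.0 + std::exp(-x));  // sp'(x) = sigmoid(x)
      return t + x * sig * (1.0 - t * t);
    }

    int main() {
      const double h = 1e-6;
      const double xs[] = {-5.0, -1.0, 0.0, 1.0, 5.0, 25.0};
      for (double x : xs) {
        // Central difference should agree with the closed form to ~1e-8.
        const double fd = (mish(x + h) - mish(x - h)) / (2.0 * h);
        std::printf("x = %5.1f  closed = %.8f  central diff = %.8f\n",
                    x, mish_grad(x), fd);
      }
      return 0;
    }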
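
On the mshadow side, `MXNET_UNARY_MATH_OP(name, expr)` only admits a single expression, while the guarded softrelu needs statements; hence the switch to hand-written structs. The new structs keep the same shape the macro expands to, roughly as sketched below (simplified for orientation; see mshadow_op.h for the authoritative definition):

    // Rough shape of the MXNET_UNARY_MATH_OP expansion (simplified sketch).
    #define MXNET_UNARY_MATH_OP(name, expr)          \
      struct name : public mxnet_op::tunable {       \
        template <typename DType>                    \
        MSHADOW_XINLINE static DType Map(DType a) {  \
          return DType(expr);                        \
        }                                            \
      }

Deriving from `mxnet_op::tunable` and casting the result back to `DType` in `Map` preserves the operator-tuning and mixed-precision behavior the macro-generated versions had.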