17 changes: 8 additions & 9 deletions src/common/cuda/rtc/backward_functions-inl.h
@@ -50,15 +50,6 @@ backward_log_sigmoid(const DTypeGrad grad, const DType val) {
return grad * 1 / (1 + op::exp(val));
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_mish(const DTypeGrad grad, const DType val) {
const mixed_type<DTypeGrad, DType> v = val;
const auto softrelu = op::log(1 + exp(v));
const auto tanh = op::tanh(softrelu);
return grad * (tanh + v * sigmoid(v) * (1 - tanh * tanh));
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_softrelu(const DTypeGrad grad, const DType val) {
@@ -212,6 +203,14 @@ backward_arctanh(const DTypeGrad grad, const DType val) {
return grad / (1 - val * val);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_mish(const DTypeGrad grad, const DType val) {
const auto softrelu = op::softrelu(val);
const auto tanh_sr = op::tanh(softrelu);
return grad * (tanh_sr + val * sigmoid(val) * (1 - tanh_sr * tanh_sr));
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_sqrt(const DTypeGrad grad, const DType out) {
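
The relocated backward_mish now obtains softrelu via op::softrelu instead of recomputing log(1 + exp(val)) inline, so the gradient tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh^2(softplus(x))) reuses the overflow-guarded softplus. A minimal host-side C++ sketch, not part of the PR, that checks this formula against a central finite difference; the names softplus/mish/mish_grad and the 20.0 cutoff are illustrative assumptions mirroring the CPU path further below:

// Standalone sanity check (illustrative, not repository code): compare the
// analytic mish gradient used by backward_mish with a finite difference.
#include <cmath>
#include <cstdio>

// Softplus with an overflow guard: for large x, log1p(exp(x)) == x in practice.
double softplus(double x) { return x > 20.0 ? x : std::log1p(std::exp(x)); }
double mish(double x) { return x * std::tanh(softplus(x)); }

double mish_grad(double x) {
  const double t = std::tanh(softplus(x));
  const double sig = 1.0 / (1.0 + std::exp(-x));  // d/dx softplus(x) = sigmoid(x)
  return t + x * sig * (1.0 - t * t);
}

int main() {
  const double h = 1e-6;
  for (double x : {-25.0, -1.0, 0.0, 0.5, 3.0, 25.0}) {
    const double fd = (mish(x + h) - mish(x - h)) / (2.0 * h);
    std::printf("x = %6.2f  analytic = %.6f  finite diff = %.6f\n",
                x, mish_grad(x), fd);
  }
  return 0;
}
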
14 changes: 5 additions & 9 deletions src/common/cuda/rtc/forward_functions-inl.h
@@ -694,15 +694,6 @@ __device__ inline DType log_sigmoid(const DType val) {
}
}

template <typename DType>
__device__ inline DType mish(const DType val) {
if (type_util::has_double_or_integral<DType>::value) {
return val * ::tanh(::log(1 + ::exp(val)));
} else {
return val * ::tanhf(logf(1 + expf(val)));
}
}

template <typename DType>
__device__ inline DType softrelu(const DType val) {
// Avoid overflow of exp for large inputs.
@@ -780,6 +771,11 @@ DEFINE_UNARY_MATH_FUNC(arcsinh, ::asinh, ::asinhf)
DEFINE_UNARY_MATH_FUNC(arccosh, ::acosh, ::acoshf)
DEFINE_UNARY_MATH_FUNC(arctanh, ::atanh, ::atanhf)

template <typename DType>
__device__ inline DType mish(const DType val) {
return val * op::tanh(op::softrelu(val));
}

// sqrt

DEFINE_UNARY_MATH_FUNC(sqrt, ::sqrt, ::sqrtf)
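
The forward mish is likewise reduced to val * op::tanh(op::softrelu(val)), dropping the per-type dispatch and reusing op::softrelu, which (per the comment in the surrounding hunk) avoids overflow of exp for large inputs. A hedged float-precision probe of why that guard matters: expf saturates to inf near x ≈ 88.7, so the naive log1p(exp(x)) produces an infinite intermediate, while a guarded version stays finite. softrelu_naive, softrelu_guarded, and the 20.0f threshold below are assumptions for illustration, not the library's exact code:

// Float-precision probe (illustrative only): naive vs. guarded softplus.
#include <cmath>
#include <cstdio>

float softrelu_naive(float x)   { return std::log1p(std::exp(x)); }
float softrelu_guarded(float x) { return x > 20.0f ? x : std::log1p(std::exp(x)); }

int main() {
  for (float x : {1.0f, 20.0f, 87.0f, 89.0f, 200.0f}) {
    std::printf("x = %6.1f  exp(x) = %12g  naive = %12g  guarded = %12g  mish = %g\n",
                x, std::exp(x), softrelu_naive(x), softrelu_guarded(x),
                x * std::tanh(softrelu_guarded(x)));
  }
  return 0;
}
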
28 changes: 24 additions & 4 deletions src/operator/mshadow_op.h
@@ -415,11 +415,31 @@ MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));

MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));

MXNET_UNARY_MATH_OP(mish, a * math::tanh(math::log(1.0f + math::exp(a))));
struct mish : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// softrelu(a) = log1p(exp(a)); mirror the softrelu op and fall back to a for large inputs to avoid overflow of exp.
auto softrelu = math::log1p(math::exp(a));
if (a > DType(20.0f)) {
softrelu = a;
}
return DType(a * math::tanh(softrelu));
}
};

MXNET_UNARY_MATH_OP(mish_grad, math::tanh(math::log(1.0f + math::exp(a))) +
a * (1.0f / (1.0f + math::exp(-a))) *
(1.0f - math::sqr(math::tanh(math::log(1.0f + math::exp(a))))));
struct mish_grad : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// Note: the input (a) is the forward input x, not the output y.
auto softrelu = math::log1p(math::exp(a));
if (a > DType(20.0f)) {
softrelu = a;
}
auto tanh_sr = math::tanh(softrelu);
auto sr_grad = 1.0f / (1.0f + math::exp(-a));
return DType(tanh_sr + a * sr_grad * (1.0f - tanh_sr * tanh_sr));
}
};

MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));

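
On the CPU side, mish and mish_grad move from one-line MXNET_UNARY_MATH_OP macros to explicit tunable structs so the softplus term can fall back to a for large inputs. The cutoff is safe because log1p(exp(a)) = a + log1p(exp(-a)), and for a > 20 the correction log1p(exp(-a)) ≈ 2e-9 is well below float's relative precision at that magnitude. A small standalone check of that claim (illustrative only, not repository code):

// Show that for a > 20 the exact softplus a + log1p(exp(-a)) rounds to a in
// float, so the branch in mish/mish_grad loses no accuracy.
#include <cmath>
#include <cstdio>

int main() {
  for (float a : {10.0f, 20.0f, 25.0f, 40.0f}) {
    // Compute the correction term in double to see what float rounds away.
    const double correction = std::log1p(std::exp(-static_cast<double>(a)));
    const float exact   = static_cast<float>(a + correction);
    const float clamped = a;
    std::printf("a = %5.1f  correction = %.3e  exact == clamped? %s\n",
                a, correction, exact == clamped ? "yes" : "no");
  }
  return 0;
}
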