diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index e6f8915ab2ff..2705177f951d 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -174,7 +174,7 @@ void ActivationGradComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ct
       break;
     case activation::kSoftSign:
       ActivationBackward<xpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
-          ctx, inputs[0], inputs[1], req[0], outputs[0]);
+          ctx, inputs[0], inputs[2], req[0], outputs[0]);
       break;
     default:
       LOG(FATAL) << "unknown activation type";
@@ -198,12 +198,13 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
-#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
   bool relu = param.act_type == activation::kReLU;
   CHECK_EQ(inputs.size(), relu ? 2U : 3U);
 #else
-  CHECK_EQ(inputs.size(), 2U);
+  bool softsign = param.act_type == activation::kSoftSign;
+  CHECK_EQ(inputs.size(), softsign ? 3U : 2U);
 #endif
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index d723bbe62d76..3404b5b700af 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -44,11 +44,19 @@ struct ActivationGrad {
                  const std::vector<nnvm::NodeEntry>& ograds) const {
     std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
     heads.emplace_back(nnvm::NodeEntry{n, activation::kOut, 0});
-#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
+    const NodeAttrs& attrs = n->attrs;
+    int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
+    if (act_type == activation::kSoftSign) {
+      // for softsign, the input data is needed to compute the gradient.
+      heads.push_back(n->inputs[activation::kData]);
+    }
+
+#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
     // for ReLU, no need to pass input data. This enables inplace optimization during the
     // forward pass.
-    if (dmlc::get<ActivationParam>(attrs.parsed).act_type != activation::kReLU) {
+    if (act_type != activation::kReLU &&
+        act_type != activation::kSoftSign) {
       heads.push_back(n->inputs[activation::kData]);
     }
 #endif
@@ -118,8 +126,8 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
                                           std::vector<int> *in_attrs,
                                           std::vector<int> *out_attrs) {
   bool ret = false;
-#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
   if (param.act_type != activation::kReLU) {
     CHECK_EQ(in_attrs->size(), 3U);
     ret = ElemwiseStorageType<3, 1, false, false, false>(attrs, dev_mask,
                                                          dispatch_mode,
                                                          in_attrs, out_attrs);
   }
@@ -133,10 +141,17 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
                                                          in_attrs, out_attrs);
   }
 #else
-  CHECK_EQ(in_attrs->size(), 2U);
-  ret = ElemwiseStorageType<2, 1, false, false, false>(attrs, dev_mask,
-                                                       dispatch_mode,
-                                                       in_attrs, out_attrs);
+  if (param.act_type == activation::kSoftSign) {
+    CHECK_EQ(in_attrs->size(), 3U);
+    ret = ElemwiseStorageType<3, 1, false, false, false>(attrs, dev_mask,
+                                                         dispatch_mode,
+                                                         in_attrs, out_attrs);
+  } else {
+    CHECK_EQ(in_attrs->size(), 2U);
+    ret = ElemwiseStorageType<2, 1, false, false, false>(attrs, dev_mask,
+                                                         dispatch_mode,
+                                                         in_attrs, out_attrs);
+  }
 #endif
   CHECK_EQ(out_attrs->size(), 1U);
 #if MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu
index 68b4053efdda..8892cc34f710 100644
--- a/src/operator/nn/activation.cu
+++ b/src/operator/nn/activation.cu
@@ -87,7 +87,7 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
         ctx, inputs[0], inputs[1], req[0], outputs[0]);
   } else if (param.act_type == activation::kSoftSign) {
     ActivationBackward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
-        ctx, inputs[0], inputs[1], req[0], outputs[0]);
+        ctx, inputs[0], inputs[2], req[0], outputs[0]);
   } else {
     MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       // XXX: for y = relu(x), y is passed as "in_data" to Backward()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index c8707097dd35..add9ae26932d 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6850,6 +6850,10 @@ def test_activation():
                     lambda x: np.log(1. + np.exp(x)),
                     lambda x: 1. - 1 / (1 + np.exp(x)),
                     -3.0, 3.0],
+        'softsign': [lambda x: mx.sym.Activation(x, act_type='softsign'),
+                     lambda x: x / (1. + np.abs(x)),
+                     lambda x: 1. / np.square(1. + np.abs(x)),
+                     -3.0, 3.0],
     }
     # Loop over operators
     for name, op in unary_ops.items():
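Note on the fix (not part of the patch): softsign's derivative, d/dx [x / (1 + |x|)] = 1 / (1 + |x|)^2, depends on the original input x rather than on the output y, so unlike relu/sigmoid/tanh the backward pass cannot recover it from the output alone. Per the ActivationGrad change above, the backward node's inputs are [output gradient, forward output, forward input], which is why the softsign branches now read inputs[2] instead of inputs[1]. A standalone NumPy sanity check of the forward/gradient formulas added to test_activation() (our sketch, not patch code):

    import numpy as np

    def softsign(x):
        return x / (1. + np.abs(x))

    def softsign_grad(x):
        return 1. / np.square(1. + np.abs(x))

    # Central finite differences should match the analytic gradient on the
    # same interval the unit test sweeps (-3.0, 3.0). The tolerance is loose
    # enough to absorb the O(eps) error at the |x| kink at x = 0.
    x = np.linspace(-3.0, 3.0, 101)
    eps = 1e-5
    fd = (softsign(x + eps) - softsign(x - eps)) / (2. * eps)
    assert np.allclose(fd, softsign_grad(x), atol=1e-4)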