This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
7 changes: 4 additions & 3 deletions benchmark/opperf/nd_operations/nn_activation_operators.py
@@ -36,9 +36,10 @@
8. Activation
8.1 relu
8.2 sigmoid
8.3 softrelu
8.4 softsign
8.5 tanh
8.3 log_sigmoid
8.4 softrelu
8.5 softsign
8.6 tanh

"""

2 changes: 1 addition & 1 deletion benchmark/opperf/rules/default_params.py
@@ -375,7 +375,7 @@

# For NN operators
DEFAULT_ACT_TYPE_LR = ['leaky', 'elu', 'selu', 'gelu']
DEFAULT_ACT_TYPE_ACTIVATION = ['relu', 'sigmoid', 'softrelu', 'softsign', 'tanh']
DEFAULT_ACT_TYPE_ACTIVATION = ['relu', 'sigmoid', 'log_sigmoid', 'softrelu', 'softsign', 'tanh']
DEFAULT_LABEL_SOFTMAX = [(1024, 1024), (10000, 1), (10000, 100)]

DEFAULT_LABEL_SOFTMAX_LARGE_TENSOR = [(2**32, 1)]
1 change: 1 addition & 0 deletions python/mxnet/amp/lists/symbol_bf16.py
@@ -288,6 +288,7 @@
'hard_sigmoid',
'identity',
'logical_not',
'log_sigmoid',
'max_axis',
'max',
'min',
1 change: 1 addition & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
@@ -395,6 +395,7 @@
'lamb_update_phase1',
'lamb_update_phase2',
'logical_not',
'log_sigmoid',
'max',
'min',
'mp_lamb_update_phase1',
8 changes: 8 additions & 0 deletions python/mxnet/ndarray/ndarray.py
@@ -2171,6 +2171,14 @@ def log1p(self, *args, **kwargs):
"""
return op.log1p(self, *args, **kwargs)

def log_sigmoid(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`log_sigmoid`.

The arguments are the same as for :py:func:`log_sigmoid`, with
this array as data.
"""
return op.log_sigmoid(self, *args, **kwargs)

def sqrt(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sqrt`.

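For illustration, a minimal usage sketch of the new NDArray fluent method (not part of the diff; assumes a build that includes this change, and the array values are arbitrary):

```python
import mxnet as mx

x = mx.nd.array([-2.0, 0.0, 2.0])
y = x.log_sigmoid()          # fluent form; equivalent to mx.nd.log_sigmoid(x)
print(y.asnumpy())           # approximately [-2.1269, -0.6931, -0.1269]
```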
8 changes: 8 additions & 0 deletions python/mxnet/numpy/multiarray.py
@@ -2260,6 +2260,14 @@ def log1p(self, *args, **kwargs):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute log1p')

def log_sigmoid(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`log_sigmoid`.

The arguments are the same as for :py:func:`log_sigmoid`, with
this array as data.
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute log_sigmoid')

def sqrt(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sqrt`.

8 changes: 8 additions & 0 deletions python/mxnet/symbol/symbol.py
@@ -2519,6 +2519,14 @@ def log1p(self, *args, **kwargs):
"""
return op.log1p(self, *args, **kwargs)

def log_sigmoid(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`log_sigmoid`.

The arguments are the same as for :py:func:`log_sigmoid`, with
this array as data.
"""
return op.log_sigmoid(self, *args, **kwargs)

def sqrt(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sqrt`.

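The symbolic counterpart mirrors the NDArray method. A small illustrative sketch, assuming the existing `Symbol.eval` API:

```python
import mxnet as mx

x = mx.sym.Variable('x')
y = x.log_sigmoid()                               # same as mx.sym.log_sigmoid(x)
out = y.eval(x=mx.nd.array([-2.0, 0.0, 2.0]))[0]  # evaluate the symbol on concrete data
```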
2 changes: 2 additions & 0 deletions src/api/operator/numpy_extension/npx_activation_op.cc
@@ -34,6 +34,8 @@ inline int String2MXNetActType(const std::string& s) {
return activation::kReLU;
} else if (s == "sigmoid") {
return activation::kSigmoid;
} else if (s == "log_sigmoid") {
return activation::kLogSigmoid;
} else if (s == "tanh") {
return activation::kTanh;
} else if (s == "softrelu") {
10 changes: 8 additions & 2 deletions src/common/cuda/rtc/backward_functions-inl.h
@@ -40,8 +40,14 @@ backward_relu(const DTypeGrad grad, const DType val) {

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_sigmoid(const DTypeGrad grad, const DType out) {
return grad * out * (1 - out);
backward_sigmoid(const DTypeGrad grad, const DType val) {
return grad * val * (1 - val);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log_sigmoid(const DTypeGrad grad, const DType val) {
return grad * 1 / (1 + op::exp(val));
}

template <typename DType, typename DTypeGrad>
9 changes: 9 additions & 0 deletions src/common/cuda/rtc/forward_functions-inl.h
@@ -685,6 +685,15 @@ __device__ inline DType sigmoid(const DType val) {
}
}

template <typename DType>
__device__ inline DType log_sigmoid(const DType val) {
if (type_util::has_double_or_integral<DType>::value) {
return ::log(1./(1 + ::exp(-val)));
} else {
return ::logf(1.f/(1 + expf(-val)));
}
}

template <typename DType>
__device__ inline DType softrelu(const DType val) {
if (type_util::has_double_or_integral<DType>::value) {
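For reference, the `log_sigmoid` kernel above is mathematically equivalent to the negated softplus of the negated input, which relates it to the `softrelu` kernel defined just below (identity only; the code keeps the direct form):

```latex
\log\sigma(x) \;=\; \log\frac{1}{1+e^{-x}} \;=\; -\log\!\left(1+e^{-x}\right) \;=\; -\mathrm{softrelu}(-x)
```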
2 changes: 2 additions & 0 deletions src/operator/fusion/fused_op-inl.h
@@ -56,6 +56,7 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
{"_backward_amp_cast" , {{"op::identity(%)", "_0"}}},
{"relu" , {{"op::relu(%)", "_0"}}},
{"sigmoid" , {{"op::sigmoid(%)", "_0"}}},
{"log_sigmoid" , {{"op::log_sigmoid(%)", "_0"}}},
{"softsign" , {{"op::softsign(%)", "_0"}}},
{"exp" , {{"op::exp(%)", "_0"}}},
{"expm1" , {{"op::expm1(%)", "_0"}}},
@@ -135,6 +136,7 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
{"logical_not" , {{"op::logical_not(%)", "_0"}}},
{"_backward_relu" , {{"op::backward_relu(%, %)", "_0", "_1"}}},
{"_backward_sigmoid" , {{"op::backward_sigmoid(%, %)", "_0", "_1"}}},
{"_backward_log_sigmoid" , {{"op::backward_log_sigmoid(%, %)", "_0", "_1"}}},
{"_backward_expm1" , {{"op::backward_expm1(%, %)", "_0", "_1"}}},
{"_backward_log" , {{"op::backward_log(%, %)", "_0", "_1"}}},
{"_backward_log10" , {{"op::backward_log10(%, %)", "_0", "_1"}}},
4 changes: 4 additions & 0 deletions src/operator/mshadow_op.h
@@ -411,6 +411,10 @@ MXNET_UNARY_MATH_OP(sigmoid, 1.0f / (1.0f + math::exp(-a)));

MXNET_UNARY_MATH_OP(sigmoid_grad, math::id(a) * (1.0f - math::id(a)));

MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));

MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));

MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));

MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));
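The gradient used by `log_sigmoid_grad` above (and by `backward_log_sigmoid` in the RTC header) follows from the chain rule with \(\sigma(x) = 1/(1+e^{-x})\) and \(\sigma'(x) = \sigma(x)(1-\sigma(x))\):

```latex
\frac{d}{dx}\,\log\sigma(x) \;=\; \frac{\sigma'(x)}{\sigma(x)} \;=\; 1-\sigma(x) \;=\; \frac{1}{1+e^{x}}
```

Because this is expressed in terms of the input x rather than the output, the backward passes take the operator's input, consistent with the `ElemwiseGradUseIn` and `n->inputs[activation::kData]` wiring later in this diff.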
13 changes: 12 additions & 1 deletion src/operator/nn/activation-inl.h
@@ -47,7 +47,7 @@ namespace activation {
enum ActivationOpInputs {kData};
enum ActivationOpOutputs {kOut};
enum ActivationOpResource {kTempSpace};
enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign};
enum ActivationOpType {kReLU, kSigmoid, kLogSigmoid, kTanh, kSoftReLU, kSoftSign};

// Get the number of inputs to the gradient depending on the activation type
int GradNumInputs(int act_type);
@@ -60,6 +60,7 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
DMLC_DECLARE_FIELD(act_type)
.add_enum("relu", activation::kReLU)
.add_enum("sigmoid", activation::kSigmoid)
.add_enum("log_sigmoid", activation::kLogSigmoid)
.add_enum("tanh", activation::kTanh)
.add_enum("softrelu", activation::kSoftReLU)
.add_enum("softsign", activation::kSoftSign)
@@ -75,6 +76,8 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
return "relu";
case activation::kSigmoid:
return "sigmoid";
case activation::kLogSigmoid:
return "log_sigmoid";
case activation::kTanh:
return "tanh";
case activation::kSoftReLU:
@@ -159,6 +162,10 @@ void ActivationComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
ActivationForward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
ctx, inputs[0], req[0], outputs[0]);
break;
case activation::kLogSigmoid:
ActivationForward<xpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs[0], req[0], outputs[0]);
break;
case activation::kTanh:
ActivationForward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
ctx, inputs[0], req[0], outputs[0]);
@@ -190,6 +197,10 @@ void ActivationGradComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ct
ActivationBackward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
break;
case activation::kLogSigmoid:
ActivationBackward<xpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
break;
case activation::kTanh:
ActivationBackward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
3 changes: 3 additions & 0 deletions src/operator/nn/activation.cc
@@ -51,6 +51,7 @@ int GradNumInputs(int act_type) {
case kSoftSign:
case kTanh:
case kSigmoid:
case kLogSigmoid:
return 3;
default:
CHECK(false) << "missing activation type";
@@ -91,6 +92,7 @@ struct ActivationGrad {
case kSoftSign:
case kTanh:
case kSigmoid:
case kLogSigmoid:
heads.push_back(n->inputs[activation::kData]);
break;
default:
@@ -168,6 +170,7 @@ The following activation functions are supported:

- `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
- `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}`
- `log_sigmoid`: :math:`y = log(\frac{1}{1 + exp(-x)})`
- `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
- `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`
- `softsign`: :math:`y = \frac{x}{1 + abs(x)}`
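A short sketch of the new `act_type` going through the `Activation` operator (illustration only, not part of the diff; values are arbitrary):

```python
import mxnet as mx

x = mx.nd.array([-2.0, 0.0, 2.0])
a = mx.nd.Activation(x, act_type='log_sigmoid')  # new act_type added above
b = mx.nd.log_sigmoid(x)                         # standalone operator registered in this PR
# a and b should agree element-wise
```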
3 changes: 3 additions & 0 deletions src/operator/nn/activation.cu
@@ -115,6 +115,9 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
} else if (act_type == activation::kSigmoid) {
ActivationBackward<gpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else if (act_type == activation::kLogSigmoid) {
ActivationBackward<gpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else {
LOG(FATAL) << "unknown activation type";
}
3 changes: 3 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -43,6 +43,7 @@ namespace op {
bool SupportMKLDNNAct(const ActivationParam& param) {
return param.act_type == activation::kReLU
|| param.act_type == activation::kSigmoid
|| param.act_type == activation::kLogSigmoid
|| param.act_type == activation::kSoftReLU
|| param.act_type == activation::kTanh;
}
@@ -83,6 +84,8 @@ mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
return mkldnn::algorithm::eltwise_relu;
case activation::kSigmoid:
return mkldnn::algorithm::eltwise_logistic;
case activation::kLogSigmoid:
return mkldnn::algorithm::eltwise_logsigmoid;
case activation::kTanh:
return mkldnn::algorithm::eltwise_tanh;
case activation::kSoftReLU:
2 changes: 2 additions & 0 deletions src/operator/operator_tune.cc
@@ -236,6 +236,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log_sigmoid); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_sigmoid_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softsign); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softsign_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu); // NOLINT()
17 changes: 17 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -149,6 +149,23 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid,
return ret;
});

// log_sigmoid
MXNET_OPERATOR_REGISTER_UNARY(log_sigmoid)
MXNET_ADD_SPARSE_OP_ALIAS(log_sigmoid)
.describe(R"code(Computes log_sigmoid of x element-wise.

.. math::
y = log(1 / (1 + exp(-x)))

The storage type of ``log_sigmoid`` output is always dense.

)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::log_sigmoid>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log_sigmoid"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_log_sigmoid,
unary_bwd<mshadow_op::log_sigmoid_grad>);



DMLC_REGISTER_PARAMETER(HardSigmoidParam);
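A quick end-to-end sanity check of the registered operator against the NumPy reference used by the new unit test (sketch only; assumes a build with this change):

```python
import mxnet as mx
import numpy as np

x = np.random.uniform(-5.0, 5.0, size=(3, 4)).astype('float32')
ref = np.log(1.0 / (1.0 + np.exp(-x)))             # formula from the docstring above
out = mx.nd.log_sigmoid(mx.nd.array(x)).asnumpy()
assert np.allclose(out, ref, rtol=1e-5, atol=1e-6)
```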
6 changes: 6 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cu
@@ -39,6 +39,12 @@ NNVM_REGISTER_OP(sigmoid)
NNVM_REGISTER_OP(_backward_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryRTCCompute{"backward_sigmoid"});

NNVM_REGISTER_OP(log_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", UnaryRTCCompute{"log_sigmoid"});

NNVM_REGISTER_OP(_backward_log_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryRTCCompute{"backward_log_sigmoid"});

NNVM_REGISTER_OP(hard_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", HardSigmoidForward<gpu>);

1 change: 1 addition & 0 deletions tests/cpp/operator/activation_perf.cc
@@ -43,6 +43,7 @@ TEST(ACTIVATION_PERF, ExecuteBidirectional) {
vector<string> activations = {
"relu",
"sigmoid",
"log_sigmoid",
"tanh",
"softrelu",
"softsign"
4 changes: 4 additions & 0 deletions tests/python/mkl/subgraphs/test_conv_subgraph.py
@@ -107,6 +107,7 @@ def hybrid_forward(self, F, x):
@pytest.mark.parametrize('alg,quantize', [
("relu", False), #TODO(bgawrych): investigate
("sigmoid", True),
("log_sigmoid", False),
("tanh", False), #TODO(bgawrych): investigate
#("softrelu", True), #TODO(bgawrych): bug in oneDNN with AVX
("relu6", False), #TODO(bgawrych): investigate
@@ -147,6 +148,7 @@ def hybrid_forward(self, F, x):
@pytest.mark.parametrize('alg,quantize', [
("relu", True),
("sigmoid", True),
("log_sigmoid", True),
("tanh", True),
("softrelu", True),
("relu6", True),
@@ -183,6 +185,7 @@ def hybrid_forward(self, F, x):
@pytest.mark.parametrize('alg,quantize', [
("relu", True),
("sigmoid", True),
("log_sigmoid", True),
("tanh", True),
#("softrelu", True), #TODO(bgawrych): failing fusion check - difference in random single element
("relu6", True),
@@ -289,6 +292,7 @@ def hybrid_forward(self, F, x, shared_weight):
@pytest.mark.parametrize('alg,quantize', [
("relu", True),
("sigmoid", True),
("log_sigmoid", True),
("tanh", True),
("softrelu", True),
("relu6", True),
4 changes: 2 additions & 2 deletions tests/python/mkl/subgraphs/test_fc_subgraph.py
@@ -23,7 +23,7 @@
from mxnet.gluon import nn
from mxnet.test_utils import assert_almost_equal_with_err

fc_post_ops_list=['relu', 'sigmoid', 'tanh', 'softrelu', 'gelu', 'elu', 'leaky',
fc_post_ops_list=['relu', 'sigmoid', 'log_sigmoid', 'tanh', 'softrelu', 'gelu', 'elu', 'leaky',
'square', 'square_root', 'abs', 'exp', 'bounded_relu']

def test_float64_fallback():
@@ -69,7 +69,7 @@ def __init__(self, use_bias, flatten, alg, **kwargs):

def hybrid_forward(self, F, x):
fc_out = self.fc(x)
if self.alg in ['relu', 'sigmoid', 'tanh', 'softrelu']:
if self.alg in ['relu', 'sigmoid', 'log_sigmoid', 'tanh', 'softrelu']:
out = F.Activation(fc_out, act_type=self.alg)
elif self.alg in ['gelu', 'elu', 'leaky']:
out = F.LeakyReLU(fc_out, act_type=self.alg)
17 changes: 16 additions & 1 deletion tests/python/unittest/test_operator.py
@@ -677,6 +677,21 @@ def fsigmoid(a):
check_symbolic_forward(y, [xa], [ya])
check_symbolic_backward(y, [xa], [np.ones(shape)], [ya * (1 - ya)])

def test_log_sigmoid():
def flog_sigmoid(a):
return np.log(np.divide(1.0, np.add(1.0, np.exp(-a))))
def flog_sigmoid_grad(a):
return np.divide(1.0, np.add(1.0, np.exp(a)))
shape = (3, 4)
x = mx.symbol.Variable("x")
y = mx.sym.log_sigmoid(x)
xa = np.random.uniform(low=-1.0,high=1.0,size=shape)
ya = flog_sigmoid(xa)
ya_grad = flog_sigmoid_grad(xa)
check_numeric_gradient(y, [xa], numeric_eps=1E-3)
check_symbolic_forward(y, [xa], [ya])
check_symbolic_backward(y, [xa], [np.ones(shape)], [ya_grad])

def test_shape_array():
for i in range(1,6):
shape = rand_shape_nd(i)
@@ -8697,7 +8712,7 @@ def test_get_operator_arguments():
assert isinstance(operator_arguments, OperatorArguments)
assert operator_arguments.names == ['data', 'act_type']
assert operator_arguments.types \
== ['NDArray-or-Symbol', "{'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required"]
== ['NDArray-or-Symbol', "{'log_sigmoid', 'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required"]
assert operator_arguments.narg == 2

