From 54a459c1e91ba29089007b1a8fc46002ac35ff0e Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 1 Apr 2021 11:58:51 -0700 Subject: [PATCH 1/6] fix softmax --- python/mxnet/ndarray/numpy_extension/_op.py | 31 ++--- python/mxnet/numpy_extension/_op.py | 12 +- .../numpy_extension/npx_softmax_op.cc | 118 +++++++++++++++++- src/operator/nn/softmax-inl.h | 15 +++ 4 files changed, 141 insertions(+), 35 deletions(-) diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py index 718022dc5b8c..2935d65f2c4c 100644 --- a/python/mxnet/ndarray/numpy_extension/_op.py +++ b/python/mxnet/ndarray/numpy_extension/_op.py @@ -134,7 +134,7 @@ def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False, # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): +def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normalize=True): r"""Applies the softmax function masking elements according to the mask provided Parameters @@ -167,22 +167,15 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): >>> data = np.arange(10).reshape((2, 5)) >>> npx.masked_softmax(data, mask, axis=0) array([[0.00669285, 0. , 0.00669285, 0. , 0.00669285], - [0.9933072 , 0. , 0.9933072 , 0. , 0.9933072 ]]) + [0.9933072 , 0. , 0.9933072 , 0. , 0.9933072 ]]) """ - if mask is not None: - neg = -1e18 - if _np.dtype(dtype) == _np.float16: - neg = -1e4 - data = np.where(mask, data, neg) - logits = (softmax(data, axis=axis) / temperature) * mask - else: - logits = softmax(data, axis=axis) / temperature - return logits + assert data is not None and mask is not None, "Missing input data and mask" + return _api_internal.masked_softmax(data, mask, axis, temperature, dtype, normalize) # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): +def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normalize=True): r"""Computes the masked log softmax of the input. This is equivalent to computing masked softmax followed by log. 
@@ -216,18 +209,10 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): >>> data = np.arange(10).reshape((2, 5)) >>> npx.masked_log_softmax(data, mask, axis=0) array([[-5.0067153 , -inf, -5.0067153 , -inf, -5.0067153 ], - [-0.00671535, -inf, -0.00671535, -inf, -0.00671535]]) + [-0.00671535, -inf, -0.00671535, -inf, -0.00671535]]) """ - if mask is not None: - neg = -1e18 - inf = -_np.inf - if _np.dtype(dtype) == _np.float16: - neg = -1e4 - data = np.where(mask, data, neg) - logits = np.where(mask, log_softmax(data, axis=axis) / temperature, inf) - else: - logits = log_softmax(data, axis=axis) / temperature - return logits + assert data is not None and mask is not None, "Missing input data and mask" + return _api_internal.masked_log_softmax(data, mask, axis, temperature, dtype, normalize) # pylint: disable=too-many-arguments, unused-argument diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index b7d75ffdc6d0..d5aa6a0090b3 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -118,7 +118,7 @@ def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False, # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): +def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normalize=True): r"""Applies the softmax function masking elements according to the mask provided Parameters @@ -151,15 +151,15 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): >>> data = np.arange(10).reshape((2, 5)) >>> npx.masked_softmax(data, mask, axis=0) array([[0.00669285, 0. , 0.00669285, 0. , 0.00669285], - [0.9933072 , 0. , 0.9933072 , 0. , 0.9933072 ]]) + [0.9933072 , 0. , 0.9933072 , 0. , 0.9933072 ]]) """ return _mx_nd_npx.masked_softmax(data, mask, axis=axis, temperature=temperature, - dtype=dtype) + dtype=dtype, normalize=normalize) # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): +def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normalize=True): r"""Computes the masked log softmax of the input. This is equivalent to computing masked softmax followed by log. 
@@ -193,10 +193,10 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): >>> data = np.arange(10).reshape((2, 5)) >>> npx.masked_log_softmax(data, mask, axis=0) array([[-5.0067153 , -inf, -5.0067153 , -inf, -5.0067153 ], - [-0.00671535, -inf, -0.00671535, -inf, -0.00671535]]) + [-0.00671535, -inf, -0.00671535, -inf, -0.00671535]]) """ return _mx_nd_npx.masked_log_softmax(data, mask, axis=axis, temperature=temperature, - dtype=dtype) + dtype=dtype, normalize=normalize) # pylint: disable=too-many-arguments, unused-argument diff --git a/src/api/operator/numpy_extension/npx_softmax_op.cc b/src/api/operator/numpy_extension/npx_softmax_op.cc index 641129e00ae9..72ca4a12e74b 100644 --- a/src/api/operator/numpy_extension/npx_softmax_op.cc +++ b/src/api/operator/numpy_extension/npx_softmax_op.cc @@ -51,9 +51,11 @@ MXNET_REGISTER_API("_npx.softmax") } // parse axis - if (args[args_size - 4].type_code() == kDLInt) { + if (args[args_size - 4].type_code() == kNull) { + param.axis = -1; + } else if (args[args_size - 4].type_code() == kDLInt) { param.axis = args[args_size - 4].operator int(); - } else { + } else if (args[args_size - 4].type_code() == kDLFloat) { param.axis = static_cast(args[args_size - 4].operator double()); } @@ -61,7 +63,7 @@ MXNET_REGISTER_API("_npx.softmax") if (args[args_size - 3].type_code() == kNull) { param.temperature = dmlc::nullopt; } else { - param.temperature = args[args_size - 3].operator int64_t(); + param.temperature = args[args_size - 3].operator double(); } // parse dtype @@ -104,9 +106,11 @@ MXNET_REGISTER_API("_npx.log_softmax") } // parse axis - if (args[args_size - 4].type_code() == kDLInt) { + if (args[args_size - 4].type_code() == kNull) { + param.axis = -1; + } else if (args[args_size - 4].type_code() == kDLInt) { param.axis = args[args_size - 4].operator int(); - } else { + } else if (args[args_size - 4].type_code() == kDLFloat) { param.axis = static_cast(args[args_size - 4].operator double()); } @@ -114,7 +118,7 @@ MXNET_REGISTER_API("_npx.log_softmax") if (args[args_size - 3].type_code() == kNull) { param.temperature = dmlc::nullopt; } else { - param.temperature = args[args_size - 3].operator int64_t(); + param.temperature = args[args_size - 3].operator double(); } // parse dtype @@ -133,4 +137,106 @@ MXNET_REGISTER_API("_npx.log_softmax") *ret = ndoutputs[0]; }); +MXNET_REGISTER_API("_npx.masked_softmax") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + static const nnvm::Op* op = Op::Get("_npx_masked_softmax"); + op::MaskedSoftmaxParam param; + + int args_size = args.size(); + // inputs + int num_inputs = 2; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // parse axis + if (args[2].type_code() == kNull) { + param.axis = -1; + } else if (args[2].type_code() == kDLInt) { + param.axis = args[2].operator int(); + } else if (args[2].type_code() == kDLFloat) { + param.axis = static_cast(args[2].operator double()); + } + // parse temperature + if (args[3].type_code() == kNull) { + param.temperature = dmlc::nullopt; + } else { + param.temperature = args[3].operator double(); + } + // parse dtype + if (args[4].type_code() == kNull) { + param.dtype = dmlc::nullopt; + } else { + param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); + } + // parse normalize + if (args[5].type_code() == kNull) { + param.normalize = true; + } else { + param.normalize 
= args[5].operator bool();
+  }
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::MaskedSoftmaxParam>(&attrs);
+
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+MXNET_REGISTER_API("_npx.masked_log_softmax")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  static const nnvm::Op* op = Op::Get("_npx_masked_log_softmax");
+  op::MaskedSoftmaxParam param;
+
+  int args_size = args.size();
+  // inputs
+  int num_inputs = 2;
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+  // parse axis
+  if (args[2].type_code() == kNull) {
+    param.axis = -1;
+  } else if (args[2].type_code() == kDLInt) {
+    param.axis = args[2].operator int();
+  } else if (args[2].type_code() == kDLFloat) {
+    param.axis = static_cast<int>(args[2].operator double());
+  }
+  // parse temperature
+  if (args[3].type_code() == kNull) {
+    param.temperature = dmlc::nullopt;
+  } else {
+    param.temperature = args[3].operator double();
+  }
+  // parse dtype
+  if (args[4].type_code() == kNull) {
+    param.dtype = dmlc::nullopt;
+  } else {
+    param.dtype = String2MXNetTypeWithBool(args[4].operator std::string());
+  }
+  // parse normalize
+  if (args[5].type_code() == kNull) {
+    param.normalize = true;
+  } else {
+    param.normalize = args[5].operator bool();
+  }
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::MaskedSoftmaxParam>(&attrs);
+
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
 }  // namespace mxnet
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 7f64b7426c3f..35c442e7d599 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -1210,6 +1210,21 @@ struct MaskedSoftmaxParam : public dmlc::Parameter<MaskedSoftmaxParam> {
     .set_default(dmlc::optional<bool>(true))
     .describe("Whether to normalize input data x: x = x - max(x)");
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream axis_s, temperature_s, dtype_s, normalize_s;
+    axis_s << axis;
+    temperature_s << temperature;
+    dtype_s << dtype;
+    normalize_s << normalize;
+    (*dict)["axis"] = axis_s.str();
+    (*dict)["temperature"] = temperature_s.str();
+    if (dtype.has_value()) {
+      (*dict)["dtype"] = MXNetTypeWithBool2String(dtype.value());
+    } else {
+      (*dict)["dtype"] = dtype_s.str();
+    }
+    (*dict)["normalize"] = normalize_s.str();
+  }
 };
 
 static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {

From 6d03f656013445fbd0a12e31d42542906d4a89a7 Mon Sep 17 00:00:00 2001
From: barry-jin
Date: Thu, 1 Apr 2021 12:03:38 -0700
Subject: [PATCH 2/6] remove import np

---
 python/mxnet/ndarray/numpy_extension/_op.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py
index 2935d65f2c4c..95f64b74b4ec 100644
--- a/python/mxnet/ndarray/numpy_extension/_op.py
+++ b/python/mxnet/ndarray/numpy_extension/_op.py
@@ -19,7 +19,6 @@
 used in Gluon dispatched by F=ndarray module."""
 
 import numpy as _np
-from .. import numpy as np  # pylint: disable=reimported
 from .._internal import NDArrayBase
 from . 
import _api_internal from ...util import set_module From c62efbb39d27429504a899a65e44f75f73f8d964 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 1 Apr 2021 13:25:17 -0700 Subject: [PATCH 3/6] add test cases --- .../numpy_extension/npx_softmax_op.cc | 14 ++-- tests/python/unittest/test_numpy_op.py | 82 ++++++++++++++++--- 2 files changed, 77 insertions(+), 19 deletions(-) diff --git a/src/api/operator/numpy_extension/npx_softmax_op.cc b/src/api/operator/numpy_extension/npx_softmax_op.cc index 72ca4a12e74b..48157f23f0b1 100644 --- a/src/api/operator/numpy_extension/npx_softmax_op.cc +++ b/src/api/operator/numpy_extension/npx_softmax_op.cc @@ -144,7 +144,6 @@ MXNET_REGISTER_API("_npx.masked_softmax") static const nnvm::Op* op = Op::Get("_npx_masked_softmax"); op::MaskedSoftmaxParam param; - int args_size = args.size(); // inputs int num_inputs = 2; std::vector inputs; @@ -153,12 +152,12 @@ MXNET_REGISTER_API("_npx.masked_softmax") inputs.push_back(args[i].operator mxnet::NDArray*()); } // parse axis - if (args[2].type_code() == kNull) { - param.axis = -1; - } else if (args[2].type_code() == kDLInt) { + if (args[2].type_code() == kDLInt) { param.axis = args[2].operator int(); } else if (args[2].type_code() == kDLFloat) { param.axis = static_cast(args[2].operator double()); + } else { + param.axis = -1; } // parse temperature if (args[3].type_code() == kNull) { @@ -195,7 +194,6 @@ MXNET_REGISTER_API("_npx.masked_log_softmax") static const nnvm::Op* op = Op::Get("_npx_masked_log_softmax"); op::MaskedSoftmaxParam param; - int args_size = args.size(); // inputs int num_inputs = 2; std::vector inputs; @@ -204,12 +202,12 @@ MXNET_REGISTER_API("_npx.masked_log_softmax") inputs.push_back(args[i].operator mxnet::NDArray*()); } // parse axis - if (args[2].type_code() == kNull) { - param.axis = -1; - } else if (args[2].type_code() == kDLInt) { + if (args[2].type_code() == kDLInt) { param.axis = args[2].operator int(); } else if (args[2].type_code() == kDLFloat) { param.axis = static_cast(args[2].operator double()); + } else { + param.axis = -1; } // parse temperature if (args[3].type_code() == kNull) { diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 6bea5109b4c6..cf776f18fae7 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1931,6 +1931,18 @@ def _test_batchnorm_impl(axis, _test_batchnorm_impl(axis, data_grad_req, gamma_grad_req, beta_grad_req) + +def np_softmax(x, axis=-1): + if (x.shape[axis] == 0): + return _np.sum(x, axis=axis, keepdims=True) + x = x - _np.max(x, axis=axis, keepdims=True) + x = _np.exp(x) + x /= _np.sum(x, axis=axis, keepdims=True) + return x + +def np_log_softmax(x, axis=-1): + return _np.log(np_softmax(x, axis)) + @use_np def test_npx_softmax(): class TestSoftmax(HybridBlock): @@ -1949,17 +1961,6 @@ def __init__(self, axis): def hybrid_forward(self, F, a): return F.npx.log_softmax(a, axis=axis) - def np_softmax(x, axis=-1): - if (x.shape[axis] == 0): - return _np.sum(x, axis=axis, keepdims=True) - x = x - _np.max(x, axis=axis, keepdims=True) - x = _np.exp(x) - x /= _np.sum(x, axis=axis, keepdims=True) - return x - - def np_log_softmax(x, axis=-1): - return _np.log(np_softmax(x, axis)) - #(operator, function) tuples tested_ops = [(TestSoftmax, np_softmax), (TestLogSoftmax, np_log_softmax)] @@ -1988,6 +1989,65 @@ def np_log_softmax(x, axis=-1): assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5) +def np_masked_softmax(data, mask, axis=-1, 
temperature=1.0): + neg = -1e18 + if data.dtype == _np.float16: + neg = -1e4 + temp = _np.where(mask, data, neg) + result = (np_softmax(temp, axis=axis) / temperature) * mask + return result + +def np_masked_log_softmax(data, mask, axis=-1, temperature=1.0): + return _np.log(np_masked_softmax(data, mask, axis, temperature)+1e-20) * mask + +@use_np +@pytest.mark.parametrize('hybridize', [True, False]) +@pytest.mark.parametrize('shape', [(3, 0, 4), (0, 0)]) +@pytest.mark.parametrize('temperature', [1.0, 2.0, 3.0]) +def test_npx_masked_softmax(hybridize, shape, temperature): + class TestMaskedSoftmax(HybridBlock): + def __init__(self, axis, temperature): + super(TestMaskedSoftmax, self).__init__() + self._axis = axis + self._temperature = temperature + + def hybrid_forward(self, F, a, mask): + return F.npx.masked_softmax(a, mask, axis=self._axis, temperature=self._temperature) + + class TestMaskedLogSoftmax(HybridBlock): + def __init__(self, axis, temperature): + super(TestMaskedLogSoftmax, self).__init__() + self._axis = axis + self._temperature = temperature + + def hybrid_forward(self, F, a, mask): + return F.npx.masked_log_softmax(a, mask, axis=self._axis, temperature=self._temperature) + + #(operator, function) tuples + tested_ops = [(TestMaskedSoftmax, np_masked_softmax), + (TestMaskedLogSoftmax, np_masked_log_softmax)] + + # only testing 0-size shaped inputs here, other input cases have been tested in test_opeartor.py + for SoftmaxOp, softmax_function in tested_ops: + mx_a = np.random.uniform(size=shape) + mask = np.random.randint(0, 2, shape) + mx_a.attach_grad() + mask.attach_grad() + for axis in range(-len(shape), len(shape)): + test_softmax_op = SoftmaxOp(axis, temperature) + if hybridize: + test_softmax_op.hybridize() + + with mx.autograd.record(): + mx_out = test_softmax_op(mx_a, mask) + + np_out = softmax_function(mx_a.asnumpy(), mask.asnumpy(), axis, temperature) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True) + + mx_out.backward() + assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5) + + @use_np def test_npi_boolean_assign(): class TestBooleanAssignScalar(HybridBlock): From c1d77b323e8f98222963deb8ba826f305e14ff29 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 1 Apr 2021 14:00:14 -0700 Subject: [PATCH 4/6] softmax axis default to -1 --- src/api/operator/numpy_extension/npx_softmax_op.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/api/operator/numpy_extension/npx_softmax_op.cc b/src/api/operator/numpy_extension/npx_softmax_op.cc index 48157f23f0b1..f6e048cae4c6 100644 --- a/src/api/operator/numpy_extension/npx_softmax_op.cc +++ b/src/api/operator/numpy_extension/npx_softmax_op.cc @@ -51,12 +51,12 @@ MXNET_REGISTER_API("_npx.softmax") } // parse axis - if (args[args_size - 4].type_code() == kNull) { - param.axis = -1; - } else if (args[args_size - 4].type_code() == kDLInt) { + if (args[args_size - 4].type_code() == kDLInt) { param.axis = args[args_size - 4].operator int(); } else if (args[args_size - 4].type_code() == kDLFloat) { param.axis = static_cast(args[args_size - 4].operator double()); + } else { + param.axis = -1; } // parse temperature @@ -106,12 +106,12 @@ MXNET_REGISTER_API("_npx.log_softmax") } // parse axis - if (args[args_size - 4].type_code() == kNull) { - param.axis = -1; - } else if (args[args_size - 4].type_code() == kDLInt) { + if (args[args_size - 4].type_code() == kDLInt) { param.axis = args[args_size - 4].operator int(); } else if (args[args_size 
- 4].type_code() == kDLFloat) { param.axis = static_cast(args[args_size - 4].operator double()); + } else { + param.axis = -1; } // parse temperature From c1e302dc2e12ab1d15fb7efd1c31bb8d0729809c Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 1 Apr 2021 15:45:18 -0700 Subject: [PATCH 5/6] remove dtype --- python/mxnet/ndarray/numpy_extension/_op.py | 14 ++++--------- python/mxnet/numpy_extension/_op.py | 14 ++++--------- .../numpy_extension/npx_softmax_op.cc | 20 ++++--------------- src/operator/nn/softmax-inl.h | 9 +-------- 4 files changed, 13 insertions(+), 44 deletions(-) diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py index 95f64b74b4ec..1ff9c2d1ff6d 100644 --- a/python/mxnet/ndarray/numpy_extension/_op.py +++ b/python/mxnet/ndarray/numpy_extension/_op.py @@ -133,7 +133,7 @@ def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False, # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normalize=True): +def masked_softmax(data, mask, axis=-1, temperature=1.0, normalize=True): r"""Applies the softmax function masking elements according to the mask provided Parameters @@ -146,9 +146,6 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normalize=T The axis along which to compute softmax. temperature : double or None, optional, default=None Temperature parameter in softmax - dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' - DType of the output in case this can't be inferred. Defaults to - the same as input's dtype if not defined (dtype=None). normalize : boolean or None, optional, default=1 Whether to normalize input data x: x = x - max(x) @@ -169,12 +166,12 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normalize=T [0.9933072 , 0. , 0.9933072 , 0. , 0.9933072 ]]) """ assert data is not None and mask is not None, "Missing input data and mask" - return _api_internal.masked_softmax(data, mask, axis, temperature, dtype, normalize) + return _api_internal.masked_softmax(data, mask, axis, temperature, normalize) # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normalize=True): +def masked_log_softmax(data, mask, axis=-1, temperature=1.0, normalize=True): r"""Computes the masked log softmax of the input. This is equivalent to computing masked softmax followed by log. @@ -188,9 +185,6 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normali The axis along which to compute softmax. temperature : double or None, optional, default=None Temperature parameter in softmax - dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' - DType of the output in case this can't be inferred. Defaults to - the same as input's dtype if not defined (dtype=None). 
normalize : boolean or None, optional, default=1 Whether to normalize input data x: x = x - max(x) @@ -211,7 +205,7 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normali [-0.00671535, -inf, -0.00671535, -inf, -0.00671535]]) """ assert data is not None and mask is not None, "Missing input data and mask" - return _api_internal.masked_log_softmax(data, mask, axis, temperature, dtype, normalize) + return _api_internal.masked_log_softmax(data, mask, axis, temperature, normalize) # pylint: disable=too-many-arguments, unused-argument diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index d5aa6a0090b3..124eb00cb76b 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -118,7 +118,7 @@ def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False, # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normalize=True): +def masked_softmax(data, mask, axis=-1, temperature=1.0, normalize=True): r"""Applies the softmax function masking elements according to the mask provided Parameters @@ -131,9 +131,6 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normalize=T The axis along which to compute softmax. temperature : double or None, optional, default=None Temperature parameter in softmax - dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' - DType of the output in case this can't be inferred. Defaults to - the same as input's dtype if not defined (dtype=None). normalize : boolean or None, optional, default=1 Whether to normalize input data x: x = x - max(x) @@ -154,12 +151,12 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normalize=T [0.9933072 , 0. , 0.9933072 , 0. , 0.9933072 ]]) """ return _mx_nd_npx.masked_softmax(data, mask, axis=axis, temperature=temperature, - dtype=dtype, normalize=normalize) + normalize=normalize) # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normalize=True): +def masked_log_softmax(data, mask, axis=-1, temperature=1.0, normalize=True): r"""Computes the masked log softmax of the input. This is equivalent to computing masked softmax followed by log. @@ -173,9 +170,6 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normali The axis along which to compute softmax. temperature : double or None, optional, default=None Temperature parameter in softmax - dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' - DType of the output in case this can't be inferred. Defaults to - the same as input's dtype if not defined (dtype=None). 
normalize : boolean or None, optional, default=1
         Whether to normalize input data x: x = x - max(x)
 
@@ -196,7 +190,7 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None, normali
     >>> data = np.arange(10).reshape((2, 5))
     >>> npx.masked_log_softmax(data, mask, axis=0)
     array([[-5.0067153 , -inf, -5.0067153 , -inf, -5.0067153 ],
-           [-0.00671535, -inf, -0.00671535, -inf, -0.00671535]])
+          [-0.00671535, -inf, -0.00671535, -inf, -0.00671535]])
     """
     return _mx_nd_npx.masked_log_softmax(data, mask, axis=axis, temperature=temperature,
-                                         dtype=dtype, normalize=normalize)
+                                         normalize=normalize)
 
 
 # pylint: disable=too-many-arguments, unused-argument
diff --git a/src/api/operator/numpy_extension/npx_softmax_op.cc b/src/api/operator/numpy_extension/npx_softmax_op.cc
index f6e048cae4c6..6e934ed4a64f 100644
--- a/src/api/operator/numpy_extension/npx_softmax_op.cc
+++ b/src/api/operator/numpy_extension/npx_softmax_op.cc
@@ -165,17 +165,11 @@ MXNET_REGISTER_API("_npx.masked_softmax")
   } else {
     param.temperature = args[3].operator double();
   }
-  // parse dtype
-  if (args[4].type_code() == kNull) {
-    param.dtype = dmlc::nullopt;
-  } else {
-    param.dtype = String2MXNetTypeWithBool(args[4].operator std::string());
-  }
   // parse normalize
-  if (args[5].type_code() == kNull) {
+  if (args[4].type_code() == kNull) {
     param.normalize = true;
   } else {
-    param.normalize = args[5].operator bool();
+    param.normalize = args[4].operator bool();
   }
 
   attrs.parsed = param;
@@ -215,17 +209,11 @@ MXNET_REGISTER_API("_npx.masked_log_softmax")
   } else {
     param.temperature = args[3].operator double();
   }
-  // parse dtype
-  if (args[4].type_code() == kNull) {
-    param.dtype = dmlc::nullopt;
-  } else {
-    param.dtype = String2MXNetTypeWithBool(args[4].operator std::string());
-  }
   // parse normalize
-  if (args[5].type_code() == kNull) {
+  if (args[4].type_code() == kNull) {
     param.normalize = true;
   } else {
-    param.normalize = args[5].operator bool();
+    param.normalize = args[4].operator bool();
   }
 
   attrs.parsed = param;
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 35c442e7d599..3f037f9c3849 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -1199,7 +1199,6 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
 struct MaskedSoftmaxParam : public dmlc::Parameter<MaskedSoftmaxParam> {
   int axis;
   dmlc::optional<double> temperature;
-  dmlc::optional<int> dtype;
   dmlc::optional<bool> normalize;
   DMLC_DECLARE_PARAMETER(MaskedSoftmaxParam) {
     DMLC_DECLARE_FIELD(axis).set_default(-1)
@@ -1211,18 +1210,12 @@ struct MaskedSoftmaxParam : public dmlc::Parameter<MaskedSoftmaxParam> {
     .describe("Whether to normalize input data x: x = x - max(x)");
   }
   void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
-    std::ostringstream axis_s, temperature_s, dtype_s, normalize_s;
+    std::ostringstream axis_s, temperature_s, normalize_s;
     axis_s << axis;
     temperature_s << temperature;
-    dtype_s << dtype;
     normalize_s << normalize;
     (*dict)["axis"] = axis_s.str();
     (*dict)["temperature"] = temperature_s.str();
-    if (dtype.has_value()) {
-      (*dict)["dtype"] = MXNetTypeWithBool2String(dtype.value());
-    } else {
-      (*dict)["dtype"] = dtype_s.str();
-    }
     (*dict)["normalize"] = normalize_s.str();
   }
 };

From 064faa3669ece92bc60d20d717e5cb3582dba444 Mon Sep 17 00:00:00 2001
From: barry-jin
Date: Thu, 1 Apr 2021 18:20:25 -0700
Subject: [PATCH 6/6] fix softmax test cases

---
 tests/python/unittest/test_numpy_op.py | 28 +++++++++++++-------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index cf776f18fae7..83ef73a75c27 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1998,30 +1998,31 @@ def np_masked_softmax(data, mask, axis=-1, 
temperature=1.0): return result def np_masked_log_softmax(data, mask, axis=-1, temperature=1.0): - return _np.log(np_masked_softmax(data, mask, axis, temperature)+1e-20) * mask + neg = -1e18 + if data.dtype == _np.float16: + neg = -1e4 + data = _np.where(mask, data, neg) + return _np.where(mask, np_log_softmax(data, axis=axis) / temperature, -_np.inf) @use_np @pytest.mark.parametrize('hybridize', [True, False]) @pytest.mark.parametrize('shape', [(3, 0, 4), (0, 0)]) -@pytest.mark.parametrize('temperature', [1.0, 2.0, 3.0]) -def test_npx_masked_softmax(hybridize, shape, temperature): +def test_npx_masked_softmax(hybridize, shape): class TestMaskedSoftmax(HybridBlock): - def __init__(self, axis, temperature): + def __init__(self, axis): super(TestMaskedSoftmax, self).__init__() self._axis = axis - self._temperature = temperature def hybrid_forward(self, F, a, mask): - return F.npx.masked_softmax(a, mask, axis=self._axis, temperature=self._temperature) + return F.npx.masked_softmax(a, mask, axis=self._axis) class TestMaskedLogSoftmax(HybridBlock): - def __init__(self, axis, temperature): + def __init__(self, axis): super(TestMaskedLogSoftmax, self).__init__() self._axis = axis - self._temperature = temperature def hybrid_forward(self, F, a, mask): - return F.npx.masked_log_softmax(a, mask, axis=self._axis, temperature=self._temperature) + return F.npx.masked_log_softmax(a, mask, axis=self._axis) #(operator, function) tuples tested_ops = [(TestMaskedSoftmax, np_masked_softmax), @@ -2034,18 +2035,17 @@ def hybrid_forward(self, F, a, mask): mx_a.attach_grad() mask.attach_grad() for axis in range(-len(shape), len(shape)): - test_softmax_op = SoftmaxOp(axis, temperature) + test_softmax_op = SoftmaxOp(axis) if hybridize: test_softmax_op.hybridize() with mx.autograd.record(): mx_out = test_softmax_op(mx_a, mask) - np_out = softmax_function(mx_a.asnumpy(), mask.asnumpy(), axis, temperature) - assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True) + mx_out.wait_to_read() - mx_out.backward() - assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5) + np_out = softmax_function(mx_a.asnumpy(), mask.asnumpy(), axis) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True) @use_np
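For reference, the behaviour the reworked tests pin down can be sketched in plain
NumPy. This is a minimal sketch mirroring the np_masked_softmax /
np_masked_log_softmax helpers from patches 3 and 6; the helper names, the float64
input, and the mask values below are illustrative choices, not taken from the
patch, and temperature handling is elided:

import numpy as np

def masked_softmax_ref(data, mask, axis=-1):
    # Masked-out logits are pushed to a large negative value, softmax is taken
    # over the surviving entries, and masked positions are zeroed in the output.
    neg = -1e4 if data.dtype == np.float16 else -1e18
    logits = np.where(mask.astype(bool), data, neg)
    logits = logits - logits.max(axis=axis, keepdims=True)
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=axis, keepdims=True)
    return probs * mask

def masked_log_softmax_ref(data, mask, axis=-1):
    # Same computation in log space; masked positions become -inf instead of 0.
    probs = masked_softmax_ref(data, mask, axis)
    return np.where(mask.astype(bool), np.log(probs + 1e-20), -np.inf)

data = np.arange(10, dtype=np.float64).reshape(2, 5)
mask = np.array([[1, 0, 1, 0, 1],
                 [1, 0, 1, 0, 1]])
print(masked_softmax_ref(data, mask, axis=0))      # ~0.0067 / 0.9933 in kept columns
print(masked_log_softmax_ref(data, mask, axis=0))  # ~-5.0067 / -0.0067, -inf where masked

Up to the 1e-20 guard inside the log, this reproduces the values shown in the
docstring examples of patch 1.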
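With the series applied, the user-facing wrappers no longer accept dtype and expose
normalize directly. A usage sketch, assuming an MXNet build that already includes
these patches (the boolean mask literal is my choice; the new unit tests pass 0/1
integer masks instead):

from mxnet import np, npx

npx.set_np()

data = np.arange(10).reshape((2, 5))
mask = np.array([[True, False, True, False, True],
                 [True, False, True, False, True]])

out = npx.masked_softmax(data, mask, axis=0, temperature=1.0, normalize=True)
log_out = npx.masked_log_softmax(data, mask, axis=0)
print(out)
print(log_out)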