From 9e2a496642d262f2dbe539941ed5f6532a8f2eca Mon Sep 17 00:00:00 2001 From: barry-jin Date: Tue, 23 Mar 2021 15:33:52 -0700 Subject: [PATCH 01/13] ffi: softmx, activattion, batch_norm, fully_connected --- python/mxnet/base.py | 5 +- .../ndarray/numpy_extension/_api_internal.py | 24 ++ python/mxnet/ndarray/numpy_extension/_op.py | 388 +++++++++++++++++- python/mxnet/numpy_extension/_op.py | 357 +++++++++++++++- .../numpy_extension/npx_activation_op.cc | 68 +++ .../numpy_extension/npx_batch_norm_op.cc | 87 ++++ .../numpy_extension/npx_fully_connected_op.cc | 66 +++ .../numpy_extension/npx_softmax_op.cc | 136 ++++++ src/operator/nn/activation-inl.h | 5 + src/operator/nn/batch_norm-inl.h | 22 + src/operator/nn/fully_connected-inl.h | 9 + src/operator/nn/softmax-inl.h | 15 + 12 files changed, 1179 insertions(+), 3 deletions(-) create mode 100644 python/mxnet/ndarray/numpy_extension/_api_internal.py create mode 100644 src/api/operator/numpy_extension/npx_activation_op.cc create mode 100644 src/api/operator/numpy_extension/npx_batch_norm_op.cc create mode 100644 src/api/operator/numpy_extension/npx_fully_connected_op.cc create mode 100644 src/api/operator/numpy_extension/npx_softmax_op.cc diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 15db63e0bff2..fa1302046474 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -794,6 +794,9 @@ def write_all_str(module_file, module_all_list): _NP_EXT_OP_PREFIX = '_npx_' _NP_EXT_OP_SUBMODULE_LIST = ['_image_', '_random_'] +_NP_EXT_OP_IMPLEMENTED_SET = {'_npx_softmax', '_npx_log_softmax', '_npx_masked_softmax', + '_npx_masked_log_softmax', '_npx_activation', + '_npx_batch_norm', '_npx_fully_connected'} _NP_INTERNAL_OP_PREFIX = '_npi_' @@ -855,7 +858,7 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op elif np_module_name == 'numpy_extension': op_name_prefix = _NP_EXT_OP_PREFIX submodule_name_list = _NP_EXT_OP_SUBMODULE_LIST - op_implemented_set = set() + op_implemented_set = _NP_EXT_OP_IMPLEMENTED_SET elif np_module_name == 'numpy._internal': op_name_prefix = _NP_INTERNAL_OP_PREFIX submodule_name_list = [] diff --git a/python/mxnet/ndarray/numpy_extension/_api_internal.py b/python/mxnet/ndarray/numpy_extension/_api_internal.py new file mode 100644 index 000000000000..b7b2216b1f83 --- /dev/null +++ b/python/mxnet/ndarray/numpy_extension/_api_internal.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
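A rough illustration of the intent behind the `_NP_EXT_OP_IMPLEMENTED_SET` added to `base.py` above (assumed behaviour; `should_generate` and its arguments are hypothetical names, not code from this diff): operators listed in the set are presumably skipped by the legacy stub generator in `_init_np_op_module`, so the hand-written FFI wrappers added later in this patch are the ones that get exported.

    def should_generate(op_name, op_implemented_set):
        # Ops that now ship a hand-written FFI wrapper are left out of code generation.
        return op_name not in op_implemented_set

    implemented = {'_npx_softmax', '_npx_activation', '_npx_batch_norm'}
    assert not should_generate('_npx_softmax', implemented)
    assert should_generate('_npx_batch_dot', implemented)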
+ +"""Namespace for numpy_extension api.""" + +from ..._ffi.function import _init_api + +__all__ = [] + +_init_api("_npx", "mxnet.ndarray.numpy_extension._api_internal") diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py index 22738a0f1950..3dfbdb42d405 100644 --- a/python/mxnet/ndarray/numpy_extension/_op.py +++ b/python/mxnet/ndarray/numpy_extension/_op.py @@ -18,4 +18,390 @@ """Namespace for the operators not belonging to the official numpy package used in Gluon dispatched by F=ndarray module.""" -__all__ = [] +import numpy as _np +from .. import numpy as np +from . import _api_internal +from ...util import set_module + + +__all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax', + 'activation', 'batch_norm', 'fully_connected'] + + +# pylint: disable=too-many-arguments +@set_module('mxnet.ndarray.numpy_extension') +def softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtype=None): + r"""Applies the softmax function. + + The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1. + + .. math:: + softmax(\mathbf{z/t})_j = \frac{e^{z_j/t}}{\sum_{k=1}^K e^{z_k/t}} + + for :math:`j = 1, ..., K` + + t is the temperature parameter in softmax function. By default, t equals 1.0 + + Parameters + ---------- + data : NDArray + The input array. + length : NDArray + The length array. + axis : int, optional, default='-1' + The axis along which to compute softmax. + temperature : double or None, optional, default=None + Temperature parameter in softmax + dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' + DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + use_length : boolean or None, optional, default=0 + Whether to use the length input as a mask over the data input. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + + Example + ------- + >>> data = np.ones((2, 3)) + >>> npx.softmax(data, axis=0) + array([[0.5, 0.5, 0.5], + [0.5, 0.5, 0.5]]) + >>> npx.softmax(data, axis=1) + array([[0.33333334, 0.33333334, 0.33333334], + [0.33333334, 0.33333334, 0.33333334]]) + """ + if dtype and not isinstance(dtype, str): + dtype = _np.dtype(dtype).name + if use_length: + assert length is not None, "Missing length input" + return _api_internal.softmax(data, length, axis, temperature, True, dtype) + else: + assert length is None, "Length input is not used" + return _api_internal.softmax(data, axis, temperature, False, dtype) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.ndarray.numpy_extension') +def log_softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtype=None): + r"""Computes the log softmax of the input. + This is equivalent to computing softmax followed by log. + + Parameters + ---------- + data : NDArray + The input array. + axis : int, optional, default='-1' + The axis along which to compute softmax. + temperature : double or None, optional, default=None + Temperature parameter in softmax + dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' + DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + use_length : boolean or None, optional, default=0 + Whether to use the length input as a mask over the data input. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. 
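As a quick sanity check of the wrappers above (a sketch assuming an MXNet build with this patch applied; `npx` is `mxnet.npx`), `log_softmax` should agree with `log` of `softmax`, and passing `use_length=True` without a `length` array trips the assertion added in the wrapper:

    from mxnet import np, npx

    data = np.array([[1.0, 2.0, 0.1]])
    lhs = npx.log_softmax(data)
    rhs = np.log(npx.softmax(data))          # same values up to floating-point error
    try:
        npx.softmax(data, use_length=True)   # length array missing
    except AssertionError as e:
        print(e)                             # "Missing length input"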
+ + Examples + -------- + >>> data = np.array([1, 2, .1]) + >>> npx.log_softmax(data) + array([-1.4170278, -0.4170278, -2.3170278]) + >>> data = np.array([[1, 2, .1],[.1, 2, 1]]) + >>> npx.log_softmax(data, axis=0) + array([[-0.34115386, -0.6931472 , -1.2411538 ], + [-1.2411538 , -0.6931472 , -0.34115386]]) + """ + if dtype and not isinstance(dtype, str): + dtype = _np.dtype(dtype).name + if use_length: + assert length is not None, "Missing length input" + return _api_internal.log_softmax(data, length, axis, temperature, True, dtype) + else: + assert length is None, "Length input is not used" + return _api_internal.log_softmax(data, axis, temperature, False, dtype) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.ndarray.numpy_extension') +def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): + r"""Applies the softmax function masking elements according to the mask provided + + Parameters + ---------- + data : NDArray + The input array. + mask : NDArray + Mask to apply. + axis : int, optional, default='-1' + The axis along which to compute softmax. + temperature : double or None, optional, default=None + Temperature parameter in softmax + dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' + DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + normalize : boolean or None, optional, default=1 + Whether to normalize input data x: x = x - max(x) + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + + Examples + -------- + >>> data = np.arange(5) + >>> mask = np.array([1, 0, 1, 0, 1]) + >>> npx.masked_softmax(data, mask) + array([0.01587624, 0. , 0.11731042, 0. , 0.8668133 ]) + >>> data = np.arange(10).reshape((2, 5)) + >>> npx.masked_softmax(data, mask, axis=0) + array([[0.00669285, 0. , 0.00669285, 0. , 0.00669285], + [0.9933072 , 0. , 0.9933072 , 0. , 0.9933072 ]]) + """ + if mask is not None: + neg = -1e18 + if _np.dtype(dtype) == _np.float16: + neg = -1e4 + data = np.where(mask, data, neg) + logits = (softmax(data, axis=axis) / temperature) * mask + else: + logits = softmax(data, axis=axis) / temperature + return logits + + +# pylint: disable=too-many-arguments +@set_module('mxnet.ndarray.numpy_extension') +def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): + r"""Computes the masked log softmax of the input. + This is equivalent to computing masked softmax followed by log. + + Parameters + ---------- + data : NDArray + The input array. + mask : NDArray + Mask to apply. + axis : int, optional, default='-1' + The axis along which to compute softmax. + temperature : double or None, optional, default=None + Temperature parameter in softmax + dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' + DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + normalize : boolean or None, optional, default=1 + Whether to normalize input data x: x = x - max(x) + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. 
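The masked variants above fill masked positions with a large negative value before taking the softmax; the float16 special case exists because -1e18 is not representable in half precision. A plain-NumPy sketch of that edge case (illustration only, not MXNet code):

    import numpy as _np

    print(_np.array([-1e18]).astype(_np.float16))   # [-inf]  -> would poison exp()
    print(_np.array([-1e4]).astype(_np.float16))    # [-10000.], still ~0 after exp()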
+ + Examples + -------- + >>> data = np.arange(5) + >>> mask = np.array([1, 0, 1, 0, 1]) + >>> npx.masked_log_softmax(data, mask) + array([-4.1429286 , -inf, -2.1429286 , -inf, -0.14292854]) + >>> data = np.arange(10).reshape((2, 5)) + >>> npx.masked_log_softmax(data, mask, axis=0) + array([[-5.0067153 , -inf, -5.0067153 , -inf, -5.0067153 ], + [-0.00671535, -inf, -0.00671535, -inf, -0.00671535]]) + """ + if mask is not None: + neg = -1e18 + inf = -_np.inf + if _np.dtype(dtype) == _np.float16: + neg = -1e4 + data = np.where(mask, data, neg) + logits = np.where(mask, log_softmax(data, axis=axis) / temperature, inf) + else: + logits = log_softmax(data, axis=axis) / temperature + return logits + + +# pylint: disable=too-many-arguments +@set_module('mxnet.ndarray.numpy_extension') +def activation(data, act_type='relu', name='fwd'): + r"""Applies an activation function element-wise to the input. + + The following activation functions are supported: + + - `relu`: Rectified Linear Unit, :math:`y = max(x, 0)` + - `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}` + - `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}` + - `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))` + - `softsign`: :math:`y = \frac{x}{1 + abs(x)}` + + Parameters + ---------- + data : NDArray + The input array. + act_type : {'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required + Activation function to be applied. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + """ + return _api_internal.activation(data, act_type) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.ndarray.numpy_extension') +def batch_norm(x, gamma, beta, running_mean, running_var, name='fwd', eps=1e-3, momentum=0.9, + fix_gamma=True, use_global_stats=False, output_mean_var=False, axis=1, + cudnn_off=False, min_calib_range=None, max_calib_range=None): + r"""Batch normalization. + + Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as + well as offset ``beta``. + + Assume the input has more than one dimension and we normalize along axis 1. + We first compute the mean and variance along this axis: + + .. math:: + + data\_mean[i] = mean(data[:,i,:,...]) \\ + data\_var[i] = var(data[:,i,:,...]) + + Then compute the normalized output, which has the same shape as input, as following: + + .. math:: + + out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} * gamma[i] + beta[i] + + Both *mean* and *var* returns a scalar by treating the input as a vector. + + Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` + have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and + the inverse of ``data_var``, which are needed for the backward pass. Note that gradient of these + two outputs are blocked. + + Besides the inputs and the outputs, this operator accepts two auxiliary + states, ``moving_mean`` and ``moving_var``, which are *k*-length + vectors. They are global statistics for the whole dataset, which are updated + by:: + + moving_mean = moving_mean * momentum + data_mean * (1 - momentum) + moving_var = moving_var * momentum + data_var * (1 - momentum) + + If ``use_global_stats`` is set to be true, then ``moving_mean`` and + ``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute + the output. It is often used during inference. 
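The moving-average update above is a standard exponential moving average; a plain-NumPy sketch of a single update step (values chosen purely for illustration):

    import numpy as _np

    momentum = 0.9
    moving_mean = _np.zeros(3)
    data_mean = _np.array([0.1, 0.2, 0.3])
    moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
    print(moving_mean)   # [0.01 0.02 0.03]; moving_var is updated the same way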
+ + The parameter ``axis`` specifies which axis of the input shape denotes + the 'channel' (separately normalized groups). The default is 1. Specifying -1 sets the channel + axis to be the last item in the input shape. + + Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true, + then set ``gamma`` to 1 and its gradient to 0. + + .. Note:: + When ``fix_gamma`` is set to True, no sparse support is provided. If ``fix_gamma is`` set to False, + the sparse tensors will fallback. + + Parameters + ---------- + data : NDArray + Input data to batch normalization + gamma : NDArray + gamma array + beta : NDArray + beta array + moving_mean : NDArray + running mean of input + moving_var : NDArray + running variance of input + eps : double, optional, default=0.0010000000474974513 + Epsilon to prevent div 0. Must be no less than CUDNN_BN_MIN_EPSILON defined in cudnn.h when using cudnn (usually 1e-5) + momentum : float, optional, default=0.899999976 + Momentum for moving average + fix_gamma : boolean, optional, default=1 + Fix gamma while training + use_global_stats : boolean, optional, default=0 + Whether use global moving statistics instead of local batch-norm. This will force change batch-norm into a scale shift operator. + output_mean_var : boolean, optional, default=0 + Output the mean and inverse std + axis : int, optional, default='1' + Specify which shape axis the channel is specified + cudnn_off : boolean, optional, default=0 + Do not select CUDNN operator, if available + min_calib_range : float or None, optional, default=None + The minimum scalar value in the form of float32 obtained through calibration. If present, it will be used to by quantized batch norm op to calculate primitive scale.Note: this calib_range is to calib bn output. + max_calib_range : float or None, optional, default=None + The maximum scalar value in the form of float32 obtained through calibration. If present, it will be used to by quantized batch norm op to calculate primitive scale.Note: this calib_range is to calib bn output. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + """ + return _api_internal.batch_norm(x, gamma, beta, running_mean, running_var, eps, momentum, + fix_gamma, use_global_stats, output_mean_var, axis, + cudnn_off, min_calib_range, max_calib_range) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.ndarray.numpy_extension') +def fully_connected(x, weight, name='fwd', bias=None, num_hidden=None, + no_bias=True, flatten=True): + r"""Applies a linear transformation: :math:`Y = XW^T + b`. + + If ``flatten`` is set to be true, then the shapes are: + + - **data**: `(batch_size, x1, x2, ..., xn)` + - **weight**: `(num_hidden, x1 * x2 * ... * xn)` + - **bias**: `(num_hidden,)` + - **out**: `(batch_size, num_hidden)` + + If ``flatten`` is set to be false, then the shapes are: + + - **data**: `(x1, x2, ..., xn, input_dim)` + - **weight**: `(num_hidden, input_dim)` + - **bias**: `(num_hidden,)` + - **out**: `(x1, x2, ..., xn, num_hidden)` + + The learnable parameters include both ``weight`` and ``bias``. + + If ``no_bias`` is set to be true, then the ``bias`` term is ignored. + + .. Note:: + + The sparse support for FullyConnected is limited to forward evaluation with `row_sparse` + weight and bias, where the length of `weight.indices` and `bias.indices` must be equal + to `num_hidden`. This could be useful for model inference with `row_sparse` weights + trained with importance sampling or noise contrastive estimation. 
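A small sketch of the shape contract described above (assuming an MXNet build with this patch applied; `npx` is `mxnet.npx`, `no_bias` defaults to True in this wrapper):

    from mxnet import np, npx

    x = np.ones((4, 3, 5))                    # (batch_size, x1, x2)
    w = np.ones((8, 15))                      # (num_hidden, x1 * x2)
    print(npx.fully_connected(x, w, num_hidden=8, flatten=True).shape)    # (4, 8)

    w2 = np.ones((8, 5))                      # (num_hidden, input_dim)
    print(npx.fully_connected(x, w2, num_hidden=8, flatten=False).shape)  # (4, 3, 8)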
+ + To compute linear transformation with 'csr' sparse data, sparse.dot is recommended instead + of sparse.FullyConnected. + + Parameters + ---------- + data : NDArray + Input data. + weight : NDArray + Weight matrix. + bias : NDArray + Bias parameter. + num_hidden : int, required + Number of hidden nodes of the output. + no_bias : boolean, optional, default=0 + Whether to disable bias parameter. + flatten : boolean, optional, default=1 + Whether to collapse all but the first axis of the input data tensor. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + """ + assert num_hidden is not None, "Please provide number of hidden layers" + if no_bias: + return _api_internal.fully_connected(x, weight, num_hidden, no_bias, flatten) + else: + assert bias is not None, "Missing bias input" + return _api_internal.fully_connected(x, weight, bias, num_hidden, no_bias, flatten) diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index a995e480221a..f7cc016fd1b8 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -17,4 +17,359 @@ """Namespace for registering numpy_extension ops for imperative programming.""" -__all__ = [] +from ..ndarray import numpy_extension as _mx_nd_npx +from ..util import set_module + + +__all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax', + 'activation', 'batch_norm', 'fully_connected'] + + +# pylint: disable=too-many-arguments +@set_module('mxnet.numpy_extension') +def softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtype=None): + r"""Applies the softmax function. + + The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1. + + .. math:: + softmax(\mathbf{z/t})_j = \frac{e^{z_j/t}}{\sum_{k=1}^K e^{z_k/t}} + + for :math:`j = 1, ..., K` + + t is the temperature parameter in softmax function. By default, t equals 1.0 + + Parameters + ---------- + data : NDArray + The input array. + length : NDArray + The length array. + axis : int, optional, default='-1' + The axis along which to compute softmax. + temperature : double or None, optional, default=None + Temperature parameter in softmax + dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' + DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + use_length : boolean or None, optional, default=0 + Whether to use the length input as a mask over the data input. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + + Example + ------- + >>> data = np.ones((2, 3)) + >>> npx.softmax(data, axis=0) + array([[0.5, 0.5, 0.5], + [0.5, 0.5, 0.5]]) + >>> npx.softmax(data, axis=1) + array([[0.33333334, 0.33333334, 0.33333334], + [0.33333334, 0.33333334, 0.33333334]]) + """ + return _mx_nd_npx.softmax(data, length, axis=axis, temperature=temperature, + use_length=use_length, dtype=dtype) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.numpy_extension') +def log_softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtype=None): + r"""Computes the log softmax of the input. + This is equivalent to computing softmax followed by log. + + Parameters + ---------- + data : NDArray + The input array. + axis : int, optional, default='-1' + The axis along which to compute softmax. 
+ temperature : double or None, optional, default=None + Temperature parameter in softmax + dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' + DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + use_length : boolean or None, optional, default=0 + Whether to use the length input as a mask over the data input. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + + Examples + -------- + >>> data = np.array([1, 2, .1]) + >>> npx.log_softmax(data) + array([-1.4170278, -0.4170278, -2.3170278]) + >>> data = np.array([[1, 2, .1],[.1, 2, 1]]) + >>> npx.log_softmax(data, axis=0) + array([[-0.34115386, -0.6931472 , -1.2411538 ], + [-1.2411538 , -0.6931472 , -0.34115386]]) + """ + return _mx_nd_npx.log_softmax(data, length, axis=axis, temperature=temperature, + use_length=use_length, dtype=dtype) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.numpy_extension') +def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): + r"""Applies the softmax function masking elements according to the mask provided + + Parameters + ---------- + data : NDArray + The input array. + mask : NDArray + Mask to apply. + axis : int, optional, default='-1' + The axis along which to compute softmax. + temperature : double or None, optional, default=None + Temperature parameter in softmax + dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' + DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + normalize : boolean or None, optional, default=1 + Whether to normalize input data x: x = x - max(x) + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + + Examples + -------- + >>> data = np.arange(5) + >>> mask = np.array([1, 0, 1, 0, 1]) + >>> npx.masked_softmax(data, mask) + array([0.01587624, 0. , 0.11731042, 0. , 0.8668133 ]) + >>> data = np.arange(10).reshape((2, 5)) + >>> npx.masked_softmax(data, mask, axis=0) + array([[0.00669285, 0. , 0.00669285, 0. , 0.00669285], + [0.9933072 , 0. , 0.9933072 , 0. , 0.9933072 ]]) + """ + return _mx_nd_npx.masked_softmax(data, mask, axis=axis, temperature=temperature, + dtype=dtype) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.numpy_extension') +def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): + r"""Computes the masked log softmax of the input. + This is equivalent to computing masked softmax followed by log. + + Parameters + ---------- + data : NDArray + The input array. + mask : NDArray + Mask to apply. + axis : int, optional, default='-1' + The axis along which to compute softmax. + temperature : double or None, optional, default=None + Temperature parameter in softmax + dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' + DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + normalize : boolean or None, optional, default=1 + Whether to normalize input data x: x = x - max(x) + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. 
+ + Examples + -------- + >>> data = np.arange(5) + >>> mask = np.array([1, 0, 1, 0, 1]) + >>> npx.masked_log_softmax(data, mask) + array([-4.1429286 , -inf, -2.1429286 , -inf, -0.14292854]) + >>> data = np.arange(10).reshape((2, 5)) + >>> npx.masked_log_softmax(data, mask, axis=0) + array([[-5.0067153 , -inf, -5.0067153 , -inf, -5.0067153 ], + [-0.00671535, -inf, -0.00671535, -inf, -0.00671535]]) + """ + return _mx_nd_npx.masked_log_softmax(data, mask, axis=axis, temperature=temperature, + dtype=dtype) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.numpy_extension') +def activation(data, act_type='relu', name='fwd'): + r"""Applies an activation function element-wise to the input. + + The following activation functions are supported: + + - `relu`: Rectified Linear Unit, :math:`y = max(x, 0)` + - `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}` + - `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}` + - `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))` + - `softsign`: :math:`y = \frac{x}{1 + abs(x)}` + + Parameters + ---------- + data : NDArray + The input array. + act_type : {'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required + Activation function to be applied. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + """ + return _mx_nd_npx.activation(data, act_type=act_type) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.numpy_extension') +def batch_norm(x, gamma, beta, running_mean, running_var, name='fwd', eps=1e-3, momentum=0.9, + fix_gamma=True, use_global_stats=False, output_mean_var=False, axis=1, + cudnn_off=False, min_calib_range=None, max_calib_range=None): + r"""Batch normalization. + + Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as + well as offset ``beta``. + + Assume the input has more than one dimension and we normalize along axis 1. + We first compute the mean and variance along this axis: + + .. math:: + + data\_mean[i] = mean(data[:,i,:,...]) \\ + data\_var[i] = var(data[:,i,:,...]) + + Then compute the normalized output, which has the same shape as input, as following: + + .. math:: + + out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} * gamma[i] + beta[i] + + Both *mean* and *var* returns a scalar by treating the input as a vector. + + Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` + have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and + the inverse of ``data_var``, which are needed for the backward pass. Note that gradient of these + two outputs are blocked. + + Besides the inputs and the outputs, this operator accepts two auxiliary + states, ``moving_mean`` and ``moving_var``, which are *k*-length + vectors. They are global statistics for the whole dataset, which are updated + by:: + + moving_mean = moving_mean * momentum + data_mean * (1 - momentum) + moving_var = moving_var * momentum + data_var * (1 - momentum) + + If ``use_global_stats`` is set to be true, then ``moving_mean`` and + ``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute + the output. It is often used during inference. + + The parameter ``axis`` specifies which axis of the input shape denotes + the 'channel' (separately normalized groups). The default is 1. Specifying -1 sets the channel + axis to be the last item in the input shape. + + Both ``gamma`` and ``beta`` are learnable parameters. 
But if ``fix_gamma`` is true, + then set ``gamma`` to 1 and its gradient to 0. + + .. Note:: + When ``fix_gamma`` is set to True, no sparse support is provided. If ``fix_gamma is`` set to False, + the sparse tensors will fallback. + + Parameters + ---------- + data : NDArray + Input data to batch normalization + gamma : NDArray + gamma array + beta : NDArray + beta array + moving_mean : NDArray + running mean of input + moving_var : NDArray + running variance of input + eps : double, optional, default=0.0010000000474974513 + Epsilon to prevent div 0. Must be no less than CUDNN_BN_MIN_EPSILON defined in cudnn.h when using cudnn (usually 1e-5) + momentum : float, optional, default=0.899999976 + Momentum for moving average + fix_gamma : boolean, optional, default=1 + Fix gamma while training + use_global_stats : boolean, optional, default=0 + Whether use global moving statistics instead of local batch-norm. This will force change batch-norm into a scale shift operator. + output_mean_var : boolean, optional, default=0 + Output the mean and inverse std + axis : int, optional, default='1' + Specify which shape axis the channel is specified + cudnn_off : boolean, optional, default=0 + Do not select CUDNN operator, if available + min_calib_range : float or None, optional, default=None + The minimum scalar value in the form of float32 obtained through calibration. If present, it will be used to by quantized batch norm op to calculate primitive scale.Note: this calib_range is to calib bn output. + max_calib_range : float or None, optional, default=None + The maximum scalar value in the form of float32 obtained through calibration. If present, it will be used to by quantized batch norm op to calculate primitive scale.Note: this calib_range is to calib bn output. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + """ + return _mx_nd_npx.batch_norm(x, gamma, beta, running_mean, running_var, eps=eps, + momentum=momentum, fix_gamma=fix_gamma, + use_global_stats=use_global_stats, + output_mean_var=output_mean_var, axis=axis, cudnn_off=cudnn_off, + min_calib_range=min_calib_range, max_calib_range=max_calib_range) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.numpy_extension') +def fully_connected(x, weight, name='fwd', bias=None, num_hidden=None, + no_bias=True, flatten=True): + r"""Applies a linear transformation: :math:`Y = XW^T + b`. + + If ``flatten`` is set to be true, then the shapes are: + + - **data**: `(batch_size, x1, x2, ..., xn)` + - **weight**: `(num_hidden, x1 * x2 * ... * xn)` + - **bias**: `(num_hidden,)` + - **out**: `(batch_size, num_hidden)` + + If ``flatten`` is set to be false, then the shapes are: + + - **data**: `(x1, x2, ..., xn, input_dim)` + - **weight**: `(num_hidden, input_dim)` + - **bias**: `(num_hidden,)` + - **out**: `(x1, x2, ..., xn, num_hidden)` + + The learnable parameters include both ``weight`` and ``bias``. + + If ``no_bias`` is set to be true, then the ``bias`` term is ignored. + + .. Note:: + + The sparse support for FullyConnected is limited to forward evaluation with `row_sparse` + weight and bias, where the length of `weight.indices` and `bias.indices` must be equal + to `num_hidden`. This could be useful for model inference with `row_sparse` weights + trained with importance sampling or noise contrastive estimation. + + To compute linear transformation with 'csr' sparse data, sparse.dot is recommended instead + of sparse.FullyConnected. + + Parameters + ---------- + data : NDArray + Input data. 
+ weight : NDArray + Weight matrix. + bias : NDArray + Bias parameter. + num_hidden : int, required + Number of hidden nodes of the output. + no_bias : boolean, optional, default=0 + Whether to disable bias parameter. + flatten : boolean, optional, default=1 + Whether to collapse all but the first axis of the input data tensor. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + """ + return _mx_nd_npx.fully_connected(x, weight, bias=bias, num_hidden=num_hidden, + no_bias=no_bias, flatten=flatten) diff --git a/src/api/operator/numpy_extension/npx_activation_op.cc b/src/api/operator/numpy_extension/npx_activation_op.cc new file mode 100644 index 000000000000..7e16053ffce1 --- /dev/null +++ b/src/api/operator/numpy_extension/npx_activation_op.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file npx_activation_op.cc + * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_activation_op.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/nn/activation-inl.h" + +namespace mxnet { + +inline int String2MXNetActType(const std::string& s) { + if (s == "relu") { + return 0; + } else if (s == "sigmoid") { + return 1; + } else if (s == "tanh") { + return 2; + } else if (s == "softrelu") { + return 3; + } else if (s == "softsign") { + return 4; + } else { + LOG(FATAL) << "unknown activation type " << s; + } + LOG(FATAL) << "should not reach here "; + return -1; +} + +MXNET_REGISTER_API("_npx.activation") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_activation"); + op::ActivationParam param; + // act_type + param.act_type = String2MXNetActType(args[1].operator std::string()); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + // inputs + NDArray* inputs[] = {args[0].operator NDArray*()}; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_batch_norm_op.cc b/src/api/operator/numpy_extension/npx_batch_norm_op.cc new file mode 100644 index 000000000000..f08df02ce003 --- /dev/null +++ b/src/api/operator/numpy_extension/npx_batch_norm_op.cc @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file npx_batch_norm_op.cc + * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_batch_norm_op.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/nn/batch_norm-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npx.batch_norm") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_batch_norm"); + op::BatchNormParam param; + // eps + param.eps = args[5].operator double(); + // momentum + param.momentum = args[6].operator double(); + // fix_gamma + param.fix_gamma = args[7].operator bool(); + // use_global_stats + param.use_global_stats = args[8].operator bool(); + // output_mean_var + param.output_mean_var = args[9].operator bool(); + // axis + param.axis= args[10].operator int(); + // cudnn_off + param.cudnn_off = args[11].operator bool(); + // min_calib_range + if (args[12].type_code() == kDLFloat || args[12].type_code() == kDLInt) { + param.min_calib_range = args[12].operator double(); + } else { + param.min_calib_range = dmlc::nullopt; + } + // max_calib_range + if (args[13].type_code() == kDLFloat || args[13].type_code() == kDLInt) { + param.max_calib_range = args[13].operator double(); + } else { + param.max_calib_range = dmlc::nullopt; + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + // inputs + int num_inputs = 5; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_fully_connected_op.cc b/src/api/operator/numpy_extension/npx_fully_connected_op.cc new file mode 100644 index 000000000000..d9ab3c02c61b --- /dev/null +++ b/src/api/operator/numpy_extension/npx_fully_connected_op.cc @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file npx_fully_connected_op.cc + * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_fully_connected_op.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/nn/fully_connected-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npx.fully_connected") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + int args_size = args.size(); + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_fully_connected"); + op::FullyConnectedParam param; + // no_bias + param.no_bias = args[args_size - 2].operator bool(); + // inputs + int num_inputs = 2; + if (param.no_bias) { + num_inputs = 2; + } else { + num_inputs = 3; + } + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // num_hidden + param.num_hidden = args[args_size - 3].operator int(); + // flatten + param.flatten = args[args_size - 1].operator bool(); + + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_softmax_op.cc b/src/api/operator/numpy_extension/npx_softmax_op.cc new file mode 100644 index 000000000000..0a1ba677c9e1 --- /dev/null +++ b/src/api/operator/numpy_extension/npx_softmax_op.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file npx_softmax_op.cc + * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_softmax_op.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/nn/softmax-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npx.softmax") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + static const nnvm::Op* op = Op::Get("_npx_softmax"); + op::SoftmaxParam param; + int args_size = args.size(); + // inputs + int num_inputs = args_size - 4; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + + // parse use_length + if (args[args_size - 2].type_code() == kNull) { + param.use_length = false; + } else { + param.use_length = args[args_size - 2].operator bool(); + } + + // parse axis + if (args[args_size - 4].type_code() == kDLInt) { + param.axis = args[args_size - 4].operator int(); + } else if (args[args_size - 4].type_code() == kDLFloat) { + param.axis = static_cast(args[args_size - 4].operator double()); + } + + // parse temperature + if (args[args_size - 3].type_code() == kNull) { + param.temperature = dmlc::nullopt; + } else { + param.temperature = args[args_size - 3].operator int64_t(); + } + + // parse dtype + if (args[args_size - 1].type_code() == kNull) { + param.dtype = dmlc::nullopt; + } else { + param.dtype = String2MXNetTypeWithBool(args[args_size - 1].operator std::string()); + } + + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; +}); + +MXNET_REGISTER_API("_npx.log_softmax") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + static const nnvm::Op* op = Op::Get("_npx_log_softmax"); + op::SoftmaxParam param; + + int args_size = args.size(); + // inputs + int num_inputs = args_size - 4; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + + // parse use_length + if (args[args_size - 2].type_code() == kNull) { + param.use_length = false; + } else { + param.use_length = args[args_size - 2].operator bool(); + } + + // parse axis + if (args[args_size - 4].type_code() == kDLInt) { + param.axis = args[args_size - 4].operator int(); + } else if (args[args_size - 4].type_code() == kDLFloat) { + param.axis = static_cast(args[args_size - 4].operator double()); + } + + // parse temperature + if (args[args_size - 3].type_code() == kNull) { + param.temperature = dmlc::nullopt; + } else { + param.temperature = args[args_size - 3].operator int64_t(); + } + + // parse dtype + if (args[args_size - 1].type_code() == kNull) { + param.dtype = dmlc::nullopt; + } else { + param.dtype = String2MXNetTypeWithBool(args[args_size - 1].operator std::string()); + } + + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; +}); + +} // namespace mxnet diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h index 06ff1fe1bedb..f90f0e84c0c5 100644 --- a/src/operator/nn/activation-inl.h +++ b/src/operator/nn/activation-inl.h @@ -69,6 +69,11 @@ struct ActivationParam : public dmlc::Parameter { bool 
operator==(const ActivationParam& other) const { return this->act_type == other.act_type; } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream act_type_s; + act_type_s << act_type; + (*dict)["act_type"] = act_type_s.str(); + } }; } // namespace op diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index 485b3b33f6a8..bb8313d3cd0e 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -125,6 +125,28 @@ struct BatchNormParam : public dmlc::Parameter { } return flag; } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream eps_s, momentum_s, fix_gamma_s, use_global_stats_s, output_mean_var_s, + axis_s, cudnn_off_s, min_calib_range_s, max_calib_range_s; + eps_s << eps; + momentum_s << momentum; + fix_gamma_s << fix_gamma; + use_global_stats_s << use_global_stats; + output_mean_var_s << output_mean_var; + axis_s << axis; + cudnn_off_s << cudnn_off; + min_calib_range_s << min_calib_range; + max_calib_range_s << max_calib_range; + (*dict)["eps"] = eps_s.str(); + (*dict)["momentum"] = momentum_s.str(); + (*dict)["fix_gamma"] = fix_gamma_s.str(); + (*dict)["use_global_stats"] = use_global_stats_s.str(); + (*dict)["output_mean_var"] = output_mean_var_s.str(); + (*dict)["axis"] = axis_s.str(); + (*dict)["cudnn_off"] = cudnn_off_s.str(); + (*dict)["min_calib_range"] = min_calib_range_s.str(); + (*dict)["max_calib_range"] = max_calib_range_s.str(); + } }; } // namespace op diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h index c90e8ce014e7..51d6f5c8d46d 100644 --- a/src/operator/nn/fully_connected-inl.h +++ b/src/operator/nn/fully_connected-inl.h @@ -80,6 +80,15 @@ struct FullyConnectedParam : public dmlc::Parameter { this->no_bias == other.no_bias && this->flatten == other.flatten; } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream num_hidden_s, no_bias_s, flatten_s; + num_hidden_s << num_hidden; + no_bias_s << no_bias; + flatten_s << flatten; + (*dict)["num_hidden"] = num_hidden_s.str(); + (*dict)["no_bias"] = no_bias_s.str(); + (*dict)["flatten"] = flatten_s.str(); + } }; /** diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 512d8d2febbb..7f64b7426c3f 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -1179,6 +1179,21 @@ struct SoftmaxParam : public dmlc::Parameter { this->dtype == other.dtype && this->use_length == other.use_length; } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream axis_s, temperature_s, dtype_s, use_length_s; + axis_s << axis; + temperature_s << temperature; + dtype_s << dtype; + use_length_s << use_length; + (*dict)["axis"] = axis_s.str(); + (*dict)["temperature"] = temperature_s.str(); + if (dtype.has_value()) { + (*dict)["dtype"] = MXNetTypeWithBool2String(dtype.value()); + } else { + (*dict)["dtype"] = dtype_s.str(); + } + (*dict)["use_length"] = use_length_s.str(); + } }; struct MaskedSoftmaxParam : public dmlc::Parameter { From bbf0c0d247b46ff4e57e6e8b23916df6157992c2 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Tue, 23 Mar 2021 16:15:46 -0700 Subject: [PATCH 02/13] fix lint --- python/mxnet/ndarray/numpy_extension/_op.py | 45 ++++++++++++------- python/mxnet/numpy_extension/_op.py | 42 ++++++++++------- .../numpy_extension/npx_activation_op.cc | 3 ++ .../numpy_extension/npx_batch_norm_op.cc | 3 ++ .../numpy_extension/npx_fully_connected_op.cc | 3 ++ 5 files changed, 63 insertions(+), 33 deletions(-) diff --git 
a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py index 3dfbdb42d405..d771934a112a 100644 --- a/python/mxnet/ndarray/numpy_extension/_op.py +++ b/python/mxnet/ndarray/numpy_extension/_op.py @@ -19,7 +19,7 @@ used in Gluon dispatched by F=ndarray module.""" import numpy as _np -from .. import numpy as np +from .. import numpy as np # pylint: disable=reimported from . import _api_internal from ...util import set_module @@ -53,7 +53,8 @@ def softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtyp temperature : double or None, optional, default=None Temperature parameter in softmax dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' - DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + DType of the output in case this can't be inferred. Defaults to + the same as input's dtype if not defined (dtype=None). use_length : boolean or None, optional, default=0 Whether to use the length input as a mask over the data input. @@ -97,7 +98,8 @@ def log_softmax(data, length=None, axis=-1, temperature=None, use_length=False, temperature : double or None, optional, default=None Temperature parameter in softmax dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' - DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + DType of the output in case this can't be inferred. Defaults to + the same as input's dtype if not defined (dtype=None). use_length : boolean or None, optional, default=0 Whether to use the length input as a mask over the data input. @@ -142,7 +144,8 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): temperature : double or None, optional, default=None Temperature parameter in softmax dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' - DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + DType of the output in case this can't be inferred. Defaults to + the same as input's dtype if not defined (dtype=None). normalize : boolean or None, optional, default=1 Whether to normalize input data x: x = x - max(x) @@ -190,7 +193,8 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): temperature : double or None, optional, default=None Temperature parameter in softmax dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' - DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + DType of the output in case this can't be inferred. Defaults to + the same as input's dtype if not defined (dtype=None). normalize : boolean or None, optional, default=1 Whether to normalize input data x: x = x - max(x) @@ -224,7 +228,7 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def activation(data, act_type='relu', name='fwd'): +def activation(data, act_type='relu', name=None): r"""Applies an activation function element-wise to the input. The following activation functions are supported: @@ -247,12 +251,12 @@ def activation(data, act_type='relu', name='fwd'): out : NDArray or list of NDArrays The output of this function. 
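A usage sketch for the activation wrapper touched in this hunk (assuming an MXNet build with the patch applied; `npx` is `mxnet.npx`, and the printed values are approximate):

    from mxnet import np, npx

    x = np.array([-1.0, 0.0, 1.0])
    print(npx.activation(x, act_type='relu'))      # [0. 0. 1.]
    print(npx.activation(x, act_type='sigmoid'))   # roughly [0.269 0.5 0.731]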
""" - return _api_internal.activation(data, act_type) + return _api_internal.activation(data, act_type, name) # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def batch_norm(x, gamma, beta, running_mean, running_var, name='fwd', eps=1e-3, momentum=0.9, +def batch_norm(x, gamma, beta, running_mean, running_var, name=None, eps=1e-3, momentum=0.9, fix_gamma=True, use_global_stats=False, output_mean_var=False, axis=1, cudnn_off=False, min_calib_range=None, max_calib_range=None): r"""Batch normalization. @@ -317,23 +321,29 @@ def batch_norm(x, gamma, beta, running_mean, running_var, name='fwd', eps=1e-3, moving_var : NDArray running variance of input eps : double, optional, default=0.0010000000474974513 - Epsilon to prevent div 0. Must be no less than CUDNN_BN_MIN_EPSILON defined in cudnn.h when using cudnn (usually 1e-5) + Epsilon to prevent div 0. Must be no less than CUDNN_BN_MIN_EPSILON + defined in cudnn.h when using cudnn (usually 1e-5) momentum : float, optional, default=0.899999976 Momentum for moving average fix_gamma : boolean, optional, default=1 Fix gamma while training use_global_stats : boolean, optional, default=0 - Whether use global moving statistics instead of local batch-norm. This will force change batch-norm into a scale shift operator. + Whether use global moving statistics instead of local batch-norm. + This will force change batch-norm into a scale shift operator. output_mean_var : boolean, optional, default=0 - Output the mean and inverse std + Output the mean and inverse std axis : int, optional, default='1' Specify which shape axis the channel is specified cudnn_off : boolean, optional, default=0 Do not select CUDNN operator, if available min_calib_range : float or None, optional, default=None - The minimum scalar value in the form of float32 obtained through calibration. If present, it will be used to by quantized batch norm op to calculate primitive scale.Note: this calib_range is to calib bn output. + The minimum scalar value in the form of float32 obtained through calibration. + If present, it will be used to by quantized batch norm op to calculate primitive scale. + Note: this calib_range is to calib bn output. max_calib_range : float or None, optional, default=None - The maximum scalar value in the form of float32 obtained through calibration. If present, it will be used to by quantized batch norm op to calculate primitive scale.Note: this calib_range is to calib bn output. + The maximum scalar value in the form of float32 obtained through calibration. + If present, it will be used to by quantized batch norm op to calculate primitive scale. + Note: this calib_range is to calib bn output. Returns ------- @@ -342,12 +352,12 @@ def batch_norm(x, gamma, beta, running_mean, running_var, name='fwd', eps=1e-3, """ return _api_internal.batch_norm(x, gamma, beta, running_mean, running_var, eps, momentum, fix_gamma, use_global_stats, output_mean_var, axis, - cudnn_off, min_calib_range, max_calib_range) + cudnn_off, min_calib_range, max_calib_range, name) # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def fully_connected(x, weight, name='fwd', bias=None, num_hidden=None, +def fully_connected(x, weight, name=None, bias=None, num_hidden=None, no_bias=True, flatten=True): r"""Applies a linear transformation: :math:`Y = XW^T + b`. 
@@ -401,7 +411,8 @@ def fully_connected(x, weight, name='fwd', bias=None, num_hidden=None, """ assert num_hidden is not None, "Please provide number of hidden layers" if no_bias: - return _api_internal.fully_connected(x, weight, num_hidden, no_bias, flatten) + return _api_internal.fully_connected(x, weight, name, num_hidden, no_bias, flatten) else: assert bias is not None, "Missing bias input" - return _api_internal.fully_connected(x, weight, bias, num_hidden, no_bias, flatten) + return _api_internal.fully_connected(x, weight, bias, name, num_hidden, + no_bias, flatten) diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index f7cc016fd1b8..d831b2bbfb73 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -50,7 +50,8 @@ def softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtyp temperature : double or None, optional, default=None Temperature parameter in softmax dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' - DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + DType of the output in case this can't be inferred. Defaults to + the same as input's dtype if not defined (dtype=None). use_length : boolean or None, optional, default=0 Whether to use the length input as a mask over the data input. @@ -88,7 +89,8 @@ def log_softmax(data, length=None, axis=-1, temperature=None, use_length=False, temperature : double or None, optional, default=None Temperature parameter in softmax dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' - DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + DType of the output in case this can't be inferred. Defaults to + the same as input's dtype if not defined (dtype=None). use_length : boolean or None, optional, default=0 Whether to use the length input as a mask over the data input. @@ -127,7 +129,8 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): temperature : double or None, optional, default=None Temperature parameter in softmax dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' - DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + DType of the output in case this can't be inferred. Defaults to + the same as input's dtype if not defined (dtype=None). normalize : boolean or None, optional, default=1 Whether to normalize input data x: x = x - max(x) @@ -168,7 +171,8 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): temperature : double or None, optional, default=None Temperature parameter in softmax dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' - DType of the output in case this can't be inferred. Defaults to the same as input's dtype if not defined (dtype=None). + DType of the output in case this can't be inferred. Defaults to + the same as input's dtype if not defined (dtype=None). normalize : boolean or None, optional, default=1 Whether to normalize input data x: x = x - max(x) @@ -194,7 +198,7 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def activation(data, act_type='relu', name='fwd'): +def activation(data, act_type='relu', name=None): r"""Applies an activation function element-wise to the input. 
The following activation functions are supported: @@ -217,12 +221,12 @@ def activation(data, act_type='relu', name='fwd'): out : NDArray or list of NDArrays The output of this function. """ - return _mx_nd_npx.activation(data, act_type=act_type) + return _mx_nd_npx.activation(data, act_type=act_type, name=name) # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def batch_norm(x, gamma, beta, running_mean, running_var, name='fwd', eps=1e-3, momentum=0.9, +def batch_norm(x, gamma, beta, running_mean, running_var, name=None, eps=1e-3, momentum=0.9, fix_gamma=True, use_global_stats=False, output_mean_var=False, axis=1, cudnn_off=False, min_calib_range=None, max_calib_range=None): r"""Batch normalization. @@ -287,39 +291,45 @@ def batch_norm(x, gamma, beta, running_mean, running_var, name='fwd', eps=1e-3, moving_var : NDArray running variance of input eps : double, optional, default=0.0010000000474974513 - Epsilon to prevent div 0. Must be no less than CUDNN_BN_MIN_EPSILON defined in cudnn.h when using cudnn (usually 1e-5) + Epsilon to prevent div 0. Must be no less than CUDNN_BN_MIN_EPSILON + defined in cudnn.h when using cudnn (usually 1e-5) momentum : float, optional, default=0.899999976 Momentum for moving average fix_gamma : boolean, optional, default=1 Fix gamma while training use_global_stats : boolean, optional, default=0 - Whether use global moving statistics instead of local batch-norm. This will force change batch-norm into a scale shift operator. + Whether use global moving statistics instead of local batch-norm. + This will force change batch-norm into a scale shift operator. output_mean_var : boolean, optional, default=0 - Output the mean and inverse std + Output the mean and inverse std axis : int, optional, default='1' Specify which shape axis the channel is specified cudnn_off : boolean, optional, default=0 Do not select CUDNN operator, if available min_calib_range : float or None, optional, default=None - The minimum scalar value in the form of float32 obtained through calibration. If present, it will be used to by quantized batch norm op to calculate primitive scale.Note: this calib_range is to calib bn output. + The minimum scalar value in the form of float32 obtained through calibration. + If present, it will be used to by quantized batch norm op to calculate primitive scale. + Note: this calib_range is to calib bn output. max_calib_range : float or None, optional, default=None - The maximum scalar value in the form of float32 obtained through calibration. If present, it will be used to by quantized batch norm op to calculate primitive scale.Note: this calib_range is to calib bn output. + The maximum scalar value in the form of float32 obtained through calibration. + If present, it will be used to by quantized batch norm op to calculate primitive scale. + Note: this calib_range is to calib bn output. Returns ------- out : NDArray or list of NDArrays The output of this function. 
""" - return _mx_nd_npx.batch_norm(x, gamma, beta, running_mean, running_var, eps=eps, + return _mx_nd_npx.batch_norm(x, gamma, beta, running_mean, running_var,eps=eps, momentum=momentum, fix_gamma=fix_gamma, - use_global_stats=use_global_stats, + use_global_stats=use_global_stats, name=name, output_mean_var=output_mean_var, axis=axis, cudnn_off=cudnn_off, min_calib_range=min_calib_range, max_calib_range=max_calib_range) # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def fully_connected(x, weight, name='fwd', bias=None, num_hidden=None, +def fully_connected(x, weight, name=None, bias=None, num_hidden=None, no_bias=True, flatten=True): r"""Applies a linear transformation: :math:`Y = XW^T + b`. @@ -372,4 +382,4 @@ def fully_connected(x, weight, name='fwd', bias=None, num_hidden=None, The output of this function. """ return _mx_nd_npx.fully_connected(x, weight, bias=bias, num_hidden=num_hidden, - no_bias=no_bias, flatten=flatten) + no_bias=no_bias, flatten=flatten, name=name) diff --git a/src/api/operator/numpy_extension/npx_activation_op.cc b/src/api/operator/numpy_extension/npx_activation_op.cc index 7e16053ffce1..aaf0c61e252b 100644 --- a/src/api/operator/numpy_extension/npx_activation_op.cc +++ b/src/api/operator/numpy_extension/npx_activation_op.cc @@ -57,6 +57,9 @@ MXNET_REGISTER_API("_npx.activation") attrs.parsed = param; attrs.op = op; SetAttrDict(&attrs); + if (args[2].type_code() != kNull) { + attrs.dict["name"] = args[2].operator std::string(); + } // inputs NDArray* inputs[] = {args[0].operator NDArray*()}; int num_inputs = 1; diff --git a/src/api/operator/numpy_extension/npx_batch_norm_op.cc b/src/api/operator/numpy_extension/npx_batch_norm_op.cc index f08df02ce003..28bcc1a50d22 100644 --- a/src/api/operator/numpy_extension/npx_batch_norm_op.cc +++ b/src/api/operator/numpy_extension/npx_batch_norm_op.cc @@ -63,6 +63,9 @@ MXNET_REGISTER_API("_npx.batch_norm") attrs.parsed = param; attrs.op = op; SetAttrDict(&attrs); + if (args[14].type_code() != kNull) { + attrs.dict["name"] = args[14].operator std::string(); + } // inputs int num_inputs = 5; std::vector inputs; diff --git a/src/api/operator/numpy_extension/npx_fully_connected_op.cc b/src/api/operator/numpy_extension/npx_fully_connected_op.cc index d9ab3c02c61b..955b255b1bbe 100644 --- a/src/api/operator/numpy_extension/npx_fully_connected_op.cc +++ b/src/api/operator/numpy_extension/npx_fully_connected_op.cc @@ -57,6 +57,9 @@ MXNET_REGISTER_API("_npx.fully_connected") attrs.parsed = param; attrs.op = op; SetAttrDict(&attrs); + if (args[args_size - 4].type_code() != kNull) { + attrs.dict["name"] = args[args_size - 4].operator std::string(); + } int num_outputs = 0; auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); From 6f63084108a09542cec0c5e85891b4f1c6fb1d78 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Tue, 23 Mar 2021 16:30:32 -0700 Subject: [PATCH 03/13] fix lint --- python/mxnet/numpy_extension/_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index d831b2bbfb73..3b6f080ba74a 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -320,7 +320,7 @@ def batch_norm(x, gamma, beta, running_mean, running_var, name=None, eps=1e-3, m out : NDArray or list of NDArrays The output of this function. 
""" - return _mx_nd_npx.batch_norm(x, gamma, beta, running_mean, running_var,eps=eps, + return _mx_nd_npx.batch_norm(x, gamma, beta, running_mean, running_var, eps=eps, momentum=momentum, fix_gamma=fix_gamma, use_global_stats=use_global_stats, name=name, output_mean_var=output_mean_var, axis=axis, cudnn_off=cudnn_off, From fa1304ed12cd4001eb45ccabf90e686099b34605 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Tue, 23 Mar 2021 16:45:32 -0700 Subject: [PATCH 04/13] fix sanity --- src/api/operator/numpy_extension/npx_batch_norm_op.cc | 2 +- src/api/operator/numpy_extension/npx_softmax_op.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/api/operator/numpy_extension/npx_batch_norm_op.cc b/src/api/operator/numpy_extension/npx_batch_norm_op.cc index 28bcc1a50d22..aef27ac066ef 100644 --- a/src/api/operator/numpy_extension/npx_batch_norm_op.cc +++ b/src/api/operator/numpy_extension/npx_batch_norm_op.cc @@ -45,7 +45,7 @@ MXNET_REGISTER_API("_npx.batch_norm") // output_mean_var param.output_mean_var = args[9].operator bool(); // axis - param.axis= args[10].operator int(); + param.axis = args[10].operator int(); // cudnn_off param.cudnn_off = args[11].operator bool(); // min_calib_range diff --git a/src/api/operator/numpy_extension/npx_softmax_op.cc b/src/api/operator/numpy_extension/npx_softmax_op.cc index 0a1ba677c9e1..d9dc80aa6ff7 100644 --- a/src/api/operator/numpy_extension/npx_softmax_op.cc +++ b/src/api/operator/numpy_extension/npx_softmax_op.cc @@ -86,7 +86,7 @@ MXNET_REGISTER_API("_npx.log_softmax") nnvm::NodeAttrs attrs; static const nnvm::Op* op = Op::Get("_npx_log_softmax"); op::SoftmaxParam param; - + int args_size = args.size(); // inputs int num_inputs = args_size - 4; From 0bad5330f003df1860426e587e95fb8db60117bd Mon Sep 17 00:00:00 2001 From: barry-jin Date: Tue, 23 Mar 2021 17:19:06 -0700 Subject: [PATCH 05/13] update softmax --- src/api/operator/numpy_extension/npx_softmax_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/api/operator/numpy_extension/npx_softmax_op.cc b/src/api/operator/numpy_extension/npx_softmax_op.cc index d9dc80aa6ff7..641129e00ae9 100644 --- a/src/api/operator/numpy_extension/npx_softmax_op.cc +++ b/src/api/operator/numpy_extension/npx_softmax_op.cc @@ -53,7 +53,7 @@ MXNET_REGISTER_API("_npx.softmax") // parse axis if (args[args_size - 4].type_code() == kDLInt) { param.axis = args[args_size - 4].operator int(); - } else if (args[args_size - 4].type_code() == kDLFloat) { + } else { param.axis = static_cast(args[args_size - 4].operator double()); } @@ -106,7 +106,7 @@ MXNET_REGISTER_API("_npx.log_softmax") // parse axis if (args[args_size - 4].type_code() == kDLInt) { param.axis = args[args_size - 4].operator int(); - } else if (args[args_size - 4].type_code() == kDLFloat) { + } else { param.axis = static_cast(args[args_size - 4].operator double()); } From 6f71af396f875985f507813285fc64a78b2bd5b4 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Wed, 24 Mar 2021 09:24:39 -0700 Subject: [PATCH 06/13] fix fully_connected --- python/mxnet/ndarray/numpy_extension/_op.py | 8 ++++---- python/mxnet/numpy_extension/_op.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py index d771934a112a..ff90e588c0b5 100644 --- a/python/mxnet/ndarray/numpy_extension/_op.py +++ b/python/mxnet/ndarray/numpy_extension/_op.py @@ -357,8 +357,8 @@ def batch_norm(x, gamma, beta, running_mean, running_var, name=None, eps=1e-3, m 
# pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def fully_connected(x, weight, name=None, bias=None, num_hidden=None, - no_bias=True, flatten=True): +def fully_connected(x, weight, bias, num_hidden=None, + no_bias=True, flatten=True, name=None): r"""Applies a linear transformation: :math:`Y = XW^T + b`. If ``flatten`` is set to be true, then the shapes are: @@ -409,10 +409,10 @@ def fully_connected(x, weight, name=None, bias=None, num_hidden=None, out : NDArray or list of NDArrays The output of this function. """ - assert num_hidden is not None, "Please provide number of hidden layers" + assert num_hidden is not None, "Please provide number of hidden nodes" if no_bias: return _api_internal.fully_connected(x, weight, name, num_hidden, no_bias, flatten) else: - assert bias is not None, "Missing bias input" + assert bias is not None, "Missing bias parameter" return _api_internal.fully_connected(x, weight, bias, name, num_hidden, no_bias, flatten) diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index 3b6f080ba74a..8a1551be0a53 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -329,8 +329,8 @@ def batch_norm(x, gamma, beta, running_mean, running_var, name=None, eps=1e-3, m # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def fully_connected(x, weight, name=None, bias=None, num_hidden=None, - no_bias=True, flatten=True): +def fully_connected(x, weight, bias, num_hidden=None, + no_bias=True, flatten=True, name=None): r"""Applies a linear transformation: :math:`Y = XW^T + b`. If ``flatten`` is set to be true, then the shapes are: @@ -381,5 +381,5 @@ def fully_connected(x, weight, name=None, bias=None, num_hidden=None, out : NDArray or list of NDArrays The output of this function. """ - return _mx_nd_npx.fully_connected(x, weight, bias=bias, num_hidden=num_hidden, + return _mx_nd_npx.fully_connected(x, weight, bias, num_hidden=num_hidden, no_bias=no_bias, flatten=flatten, name=name) From 13d4e9cb05041adb82e0a7e54399ac254deb67c4 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Wed, 24 Mar 2021 10:24:31 -0700 Subject: [PATCH 07/13] fix fully_connected --- python/mxnet/ndarray/numpy_extension/_op.py | 2 +- python/mxnet/numpy_extension/_op.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py index ff90e588c0b5..f6564e3bc5f7 100644 --- a/python/mxnet/ndarray/numpy_extension/_op.py +++ b/python/mxnet/ndarray/numpy_extension/_op.py @@ -357,7 +357,7 @@ def batch_norm(x, gamma, beta, running_mean, running_var, name=None, eps=1e-3, m # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def fully_connected(x, weight, bias, num_hidden=None, +def fully_connected(x, weight, bias=None, num_hidden=None, no_bias=True, flatten=True, name=None): r"""Applies a linear transformation: :math:`Y = XW^T + b`. 
diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index 8a1551be0a53..e2ace04bca49 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -329,7 +329,7 @@ def batch_norm(x, gamma, beta, running_mean, running_var, name=None, eps=1e-3, m # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def fully_connected(x, weight, bias, num_hidden=None, +def fully_connected(x, weight, bias=None, num_hidden=None, no_bias=True, flatten=True, name=None): r"""Applies a linear transformation: :math:`Y = XW^T + b`. From ba54fe853e989abea6d7eaa06d01164ade32eb8b Mon Sep 17 00:00:00 2001 From: barry-jin Date: Wed, 24 Mar 2021 12:05:15 -0700 Subject: [PATCH 08/13] update --- python/mxnet/ndarray/numpy_extension/_op.py | 26 ++++++++++--------- python/mxnet/numpy_extension/_op.py | 24 +++++++++-------- .../numpy_extension/npx_activation_op.cc | 3 --- .../numpy_extension/npx_batch_norm_op.cc | 3 --- .../numpy_extension/npx_fully_connected_op.cc | 3 --- 5 files changed, 27 insertions(+), 32 deletions(-) diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py index f6564e3bc5f7..bbbb288641b1 100644 --- a/python/mxnet/ndarray/numpy_extension/_op.py +++ b/python/mxnet/ndarray/numpy_extension/_op.py @@ -30,7 +30,7 @@ # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtype=None): +def softmax(data, axis=-1, length=None, temperature=None, use_length=False, dtype=None): r"""Applies the softmax function. The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1. @@ -46,10 +46,10 @@ def softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtyp ---------- data : NDArray The input array. - length : NDArray - The length array. axis : int, optional, default='-1' The axis along which to compute softmax. + length : NDArray + The length array. temperature : double or None, optional, default=None Temperature parameter in softmax dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' @@ -85,7 +85,7 @@ def softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtyp # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def log_softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtype=None): +def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False, dtype=None): r"""Computes the log softmax of the input. This is equivalent to computing softmax followed by log. @@ -95,6 +95,8 @@ def log_softmax(data, length=None, axis=-1, temperature=None, use_length=False, The input array. axis : int, optional, default='-1' The axis along which to compute softmax. + length : NDArray + The length array. temperature : double or None, optional, default=None Temperature parameter in softmax dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' @@ -228,7 +230,7 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def activation(data, act_type='relu', name=None): +def activation(data, act_type='relu', **kwargs): r"""Applies an activation function element-wise to the input. 
The following activation functions are supported: @@ -251,14 +253,14 @@ def activation(data, act_type='relu', name=None): out : NDArray or list of NDArrays The output of this function. """ - return _api_internal.activation(data, act_type, name) + return _api_internal.activation(data, act_type) # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') -def batch_norm(x, gamma, beta, running_mean, running_var, name=None, eps=1e-3, momentum=0.9, +def batch_norm(x, gamma, beta, running_mean, running_var, eps=1e-3, momentum=0.9, fix_gamma=True, use_global_stats=False, output_mean_var=False, axis=1, - cudnn_off=False, min_calib_range=None, max_calib_range=None): + cudnn_off=False, min_calib_range=None, max_calib_range=None, **kwargs): r"""Batch normalization. Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as @@ -352,13 +354,13 @@ def batch_norm(x, gamma, beta, running_mean, running_var, name=None, eps=1e-3, m """ return _api_internal.batch_norm(x, gamma, beta, running_mean, running_var, eps, momentum, fix_gamma, use_global_stats, output_mean_var, axis, - cudnn_off, min_calib_range, max_calib_range, name) + cudnn_off, min_calib_range, max_calib_range) # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') def fully_connected(x, weight, bias=None, num_hidden=None, - no_bias=True, flatten=True, name=None): + no_bias=True, flatten=True, **kwargs): r"""Applies a linear transformation: :math:`Y = XW^T + b`. If ``flatten`` is set to be true, then the shapes are: @@ -411,8 +413,8 @@ def fully_connected(x, weight, bias=None, num_hidden=None, """ assert num_hidden is not None, "Please provide number of hidden nodes" if no_bias: - return _api_internal.fully_connected(x, weight, name, num_hidden, no_bias, flatten) + return _api_internal.fully_connected(x, weight, num_hidden, no_bias, flatten) else: assert bias is not None, "Missing bias parameter" - return _api_internal.fully_connected(x, weight, bias, name, num_hidden, + return _api_internal.fully_connected(x, weight, bias, num_hidden, no_bias, flatten) diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index e2ace04bca49..8609b9b71c45 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -27,7 +27,7 @@ # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtype=None): +def softmax(data, axis=-1, length=None, temperature=None, use_length=False, dtype=None): r"""Applies the softmax function. The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1. @@ -43,10 +43,10 @@ def softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtyp ---------- data : NDArray The input array. - length : NDArray - The length array. axis : int, optional, default='-1' The axis along which to compute softmax. + length : NDArray + The length array. 
temperature : double or None, optional, default=None Temperature parameter in softmax dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' @@ -76,7 +76,7 @@ def softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtyp # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def log_softmax(data, length=None, axis=-1, temperature=None, use_length=False, dtype=None): +def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False, dtype=None): r"""Computes the log softmax of the input. This is equivalent to computing softmax followed by log. @@ -86,6 +86,8 @@ def log_softmax(data, length=None, axis=-1, temperature=None, use_length=False, The input array. axis : int, optional, default='-1' The axis along which to compute softmax. + length : NDArray + The length array. temperature : double or None, optional, default=None Temperature parameter in softmax dtype : {None, 'float16', 'float32', 'float64'},optional, default='None' @@ -198,7 +200,7 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def activation(data, act_type='relu', name=None): +def activation(data, act_type='relu', **kwargs): r"""Applies an activation function element-wise to the input. The following activation functions are supported: @@ -221,14 +223,14 @@ def activation(data, act_type='relu', name=None): out : NDArray or list of NDArrays The output of this function. """ - return _mx_nd_npx.activation(data, act_type=act_type, name=name) + return _mx_nd_npx.activation(data, act_type=act_type) # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def batch_norm(x, gamma, beta, running_mean, running_var, name=None, eps=1e-3, momentum=0.9, +def batch_norm(x, gamma, beta, running_mean, running_var, eps=1e-3, momentum=0.9, fix_gamma=True, use_global_stats=False, output_mean_var=False, axis=1, - cudnn_off=False, min_calib_range=None, max_calib_range=None): + cudnn_off=False, min_calib_range=None, max_calib_range=None, **kwargs): r"""Batch normalization. Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as @@ -322,7 +324,7 @@ def batch_norm(x, gamma, beta, running_mean, running_var, name=None, eps=1e-3, m """ return _mx_nd_npx.batch_norm(x, gamma, beta, running_mean, running_var, eps=eps, momentum=momentum, fix_gamma=fix_gamma, - use_global_stats=use_global_stats, name=name, + use_global_stats=use_global_stats, output_mean_var=output_mean_var, axis=axis, cudnn_off=cudnn_off, min_calib_range=min_calib_range, max_calib_range=max_calib_range) @@ -330,7 +332,7 @@ def batch_norm(x, gamma, beta, running_mean, running_var, name=None, eps=1e-3, m # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') def fully_connected(x, weight, bias=None, num_hidden=None, - no_bias=True, flatten=True, name=None): + no_bias=True, flatten=True, **kwargs): r"""Applies a linear transformation: :math:`Y = XW^T + b`. If ``flatten`` is set to be true, then the shapes are: @@ -382,4 +384,4 @@ def fully_connected(x, weight, bias=None, num_hidden=None, The output of this function. 
""" return _mx_nd_npx.fully_connected(x, weight, bias, num_hidden=num_hidden, - no_bias=no_bias, flatten=flatten, name=name) + no_bias=no_bias, flatten=flatten) diff --git a/src/api/operator/numpy_extension/npx_activation_op.cc b/src/api/operator/numpy_extension/npx_activation_op.cc index aaf0c61e252b..7e16053ffce1 100644 --- a/src/api/operator/numpy_extension/npx_activation_op.cc +++ b/src/api/operator/numpy_extension/npx_activation_op.cc @@ -57,9 +57,6 @@ MXNET_REGISTER_API("_npx.activation") attrs.parsed = param; attrs.op = op; SetAttrDict(&attrs); - if (args[2].type_code() != kNull) { - attrs.dict["name"] = args[2].operator std::string(); - } // inputs NDArray* inputs[] = {args[0].operator NDArray*()}; int num_inputs = 1; diff --git a/src/api/operator/numpy_extension/npx_batch_norm_op.cc b/src/api/operator/numpy_extension/npx_batch_norm_op.cc index aef27ac066ef..dcf3ac4f0df7 100644 --- a/src/api/operator/numpy_extension/npx_batch_norm_op.cc +++ b/src/api/operator/numpy_extension/npx_batch_norm_op.cc @@ -63,9 +63,6 @@ MXNET_REGISTER_API("_npx.batch_norm") attrs.parsed = param; attrs.op = op; SetAttrDict(&attrs); - if (args[14].type_code() != kNull) { - attrs.dict["name"] = args[14].operator std::string(); - } // inputs int num_inputs = 5; std::vector inputs; diff --git a/src/api/operator/numpy_extension/npx_fully_connected_op.cc b/src/api/operator/numpy_extension/npx_fully_connected_op.cc index 955b255b1bbe..d9ab3c02c61b 100644 --- a/src/api/operator/numpy_extension/npx_fully_connected_op.cc +++ b/src/api/operator/numpy_extension/npx_fully_connected_op.cc @@ -57,9 +57,6 @@ MXNET_REGISTER_API("_npx.fully_connected") attrs.parsed = param; attrs.op = op; SetAttrDict(&attrs); - if (args[args_size - 4].type_code() != kNull) { - attrs.dict["name"] = args[args_size - 4].operator std::string(); - } int num_outputs = 0; auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); From bfe2c931427c76e409bd11fe4053a609381a288c Mon Sep 17 00:00:00 2001 From: barry-jin Date: Wed, 24 Mar 2021 12:15:34 -0700 Subject: [PATCH 09/13] fix lint --- python/mxnet/ndarray/numpy_extension/_op.py | 6 +++--- python/mxnet/numpy_extension/_op.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py index bbbb288641b1..8ada24f77039 100644 --- a/python/mxnet/ndarray/numpy_extension/_op.py +++ b/python/mxnet/ndarray/numpy_extension/_op.py @@ -228,7 +228,7 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): return logits -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments, unused-argument @set_module('mxnet.ndarray.numpy_extension') def activation(data, act_type='relu', **kwargs): r"""Applies an activation function element-wise to the input. 
@@ -256,7 +256,7 @@ def activation(data, act_type='relu', **kwargs): return _api_internal.activation(data, act_type) -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments, unused-argument @set_module('mxnet.ndarray.numpy_extension') def batch_norm(x, gamma, beta, running_mean, running_var, eps=1e-3, momentum=0.9, fix_gamma=True, use_global_stats=False, output_mean_var=False, axis=1, @@ -357,7 +357,7 @@ def batch_norm(x, gamma, beta, running_mean, running_var, eps=1e-3, momentum=0.9 cudnn_off, min_calib_range, max_calib_range) -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments, unused-argument @set_module('mxnet.ndarray.numpy_extension') def fully_connected(x, weight, bias=None, num_hidden=None, no_bias=True, flatten=True, **kwargs): diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index 8609b9b71c45..c66751b2ef22 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -198,7 +198,7 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None): dtype=dtype) -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments, unused-argument @set_module('mxnet.numpy_extension') def activation(data, act_type='relu', **kwargs): r"""Applies an activation function element-wise to the input. @@ -226,7 +226,7 @@ def activation(data, act_type='relu', **kwargs): return _mx_nd_npx.activation(data, act_type=act_type) -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments, unused-argument @set_module('mxnet.numpy_extension') def batch_norm(x, gamma, beta, running_mean, running_var, eps=1e-3, momentum=0.9, fix_gamma=True, use_global_stats=False, output_mean_var=False, axis=1, @@ -329,7 +329,7 @@ def batch_norm(x, gamma, beta, running_mean, running_var, eps=1e-3, momentum=0.9 min_calib_range=min_calib_range, max_calib_range=max_calib_range) -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments, unused-argument @set_module('mxnet.numpy_extension') def fully_connected(x, weight, bias=None, num_hidden=None, no_bias=True, flatten=True, **kwargs): From c2c6412de582bffa19014145511c024f94e9784a Mon Sep 17 00:00:00 2001 From: barry-jin Date: Wed, 24 Mar 2021 12:23:34 -0700 Subject: [PATCH 10/13] fix lint --- python/mxnet/numpy_extension/_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index c66751b2ef22..d168af6b10aa 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -70,7 +70,7 @@ def softmax(data, axis=-1, length=None, temperature=None, use_length=False, dtyp array([[0.33333334, 0.33333334, 0.33333334], [0.33333334, 0.33333334, 0.33333334]]) """ - return _mx_nd_npx.softmax(data, length, axis=axis, temperature=temperature, + return _mx_nd_npx.softmax(data, axis=axis, length=length, temperature=temperature, use_length=use_length, dtype=dtype) @@ -111,7 +111,7 @@ def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False, array([[-0.34115386, -0.6931472 , -1.2411538 ], [-1.2411538 , -0.6931472 , -0.34115386]]) """ - return _mx_nd_npx.log_softmax(data, length, axis=axis, temperature=temperature, + return _mx_nd_npx.log_softmax(data, axis=axis, length=length, temperature=temperature, use_length=use_length, dtype=dtype) From 7b1f947a045c5c0b0f44e2f8f781410928d8aaae Mon Sep 17 00:00:00 2001 From: barry-jin Date: Wed, 24 Mar 2021 14:15:00 -0700 Subject: [PATCH 11/13] 
apply activation enum

---
 .../operator/numpy_extension/npx_activation_op.cc | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/api/operator/numpy_extension/npx_activation_op.cc b/src/api/operator/numpy_extension/npx_activation_op.cc
index 7e16053ffce1..e7b09fcc07c5 100644
--- a/src/api/operator/numpy_extension/npx_activation_op.cc
+++ b/src/api/operator/numpy_extension/npx_activation_op.cc
@@ -30,20 +30,20 @@ namespace mxnet {

 inline int String2MXNetActType(const std::string& s) {
   if (s == "relu") {
-    return 0;
+    return activation::kReLU;
   } else if (s == "sigmoid") {
-    return 1;
+    return activation::kSigmoid;
   } else if (s == "tanh") {
-    return 2;
+    return activation::kTanh;
   } else if (s == "softrelu") {
-    return 3;
+    return activation::kSoftReLU;
   } else if (s == "softsign") {
-    return 4;
+    return activation::kSoftSign;
   } else {
     LOG(FATAL) << "unknown activation type " << s;
   }
   LOG(FATAL) << "should not reach here ";
-  return -1;
+  return 0;
 }

 MXNET_REGISTER_API("_npx.activation")

From 3e8c6d5ec74897819d80da9343dbdd47511d626a Mon Sep 17 00:00:00 2001
From: barry-jin
Date: Wed, 24 Mar 2021 14:15:53 -0700
Subject: [PATCH 12/13] apply activation enum

---
 src/api/operator/numpy_extension/npx_activation_op.cc | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/api/operator/numpy_extension/npx_activation_op.cc b/src/api/operator/numpy_extension/npx_activation_op.cc
index e7b09fcc07c5..c072f6e9fc70 100644
--- a/src/api/operator/numpy_extension/npx_activation_op.cc
+++ b/src/api/operator/numpy_extension/npx_activation_op.cc
@@ -29,6 +29,7 @@ namespace mxnet {

 inline int String2MXNetActType(const std::string& s) {
+  using namespace op;
   if (s == "relu") {
     return activation::kReLU;
   } else if (s == "sigmoid") {

From 16a941be4e1dffe677e15cf2c751c312ca39733c Mon Sep 17 00:00:00 2001
From: barry-jin
Date: Wed, 24 Mar 2021 16:36:14 -0700
Subject: [PATCH 13/13] add act type conversion

---
 src/operator/nn/activation-inl.h | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index f90f0e84c0c5..1111464b9697 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -69,10 +69,28 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
   bool operator==(const ActivationParam& other) const {
     return this->act_type == other.act_type;
   }
+  std::string MXNetActType2String(int act_type) {
+    switch (act_type) {
+      case activation::kReLU:
+        return "relu";
+      case activation::kSigmoid:
+        return "sigmoid";
+      case activation::kTanh:
+        return "tanh";
+      case activation::kSoftReLU:
+        return "softrelu";
+      case activation::kSoftSign:
+        return "softsign";
+      default:
+        LOG(FATAL) << "Unknown act_type enum " << act_type;
+    }
+    LOG(FATAL) << "should not reach here ";
+    return "";
+  }
   void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
     std::ostringstream act_type_s;
     act_type_s << act_type;
-    (*dict)["act_type"] = act_type_s.str();
+    (*dict)["act_type"] = MXNetActType2String(act_type);
   }
 };
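
Usage sketch (illustrative note, not part of the diffs above): a minimal example of calling the FFI-backed numpy_extension operators, assuming an MXNet build with this series applied; shapes and values are made up, and the call signatures follow the Python wrappers as they stand after PATCH 08.

    # Hypothetical example for illustration only.
    from mxnet import np, npx

    npx.set_np()

    data = np.ones((2, 3))
    probs = npx.softmax(data, axis=-1)            # dispatches through the _npx.softmax FFI handle
    act = npx.activation(data, act_type='relu')   # act_type string is mapped to activation::kReLU in C++

    weight = np.ones((4, 3))
    out = npx.fully_connected(data, weight, num_hidden=4, no_bias=True, flatten=True)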