diff --git a/.github/workflows/os_x_staticbuild.yml b/.github/workflows/os_x_staticbuild.yml index 019069ac32e6..37b28f3b012e 100644 --- a/.github/workflows/os_x_staticbuild.yml +++ b/.github/workflows/os_x_staticbuild.yml @@ -54,3 +54,35 @@ jobs: python3 -m pytest -n 4 --durations=50 --verbose tests/python/unittest/ -k 'not test_operator and not (test_subgraph or test_custom_op or test_external_op or test_recordimage_dataset_with_data_loader_multiworker or test_multi_worker or test_multi_worker_shape or test_multi_worker_forked_data_loader or test_multi_worker_dataloader_release_pool)' -m 'not serial' MXNET_ENGINE_TYPE=NaiveEngine python3 -m pytest -n 4 --durations=50 --verbose tests/python/unittest/ -k 'test_operator and not (test_subgraph or test_custom_op or test_external_op or test_recordimage_dataset_with_data_loader_multiworker or test_multi_worker or test_multi_worker_shape or test_multi_worker_forked_data_loader or test_multi_worker_dataloader_release_pool)' -m 'not serial' python3 -m pytest --durations=50 --verbose tests/python/unittest/ -k 'not (test_subgraph or test_custom_op or test_external_op or test_recordimage_dataset_with_data_loader_multiworker or test_multi_worker or test_multi_worker_shape or test_multi_worker_forked_data_loader or test_multi_worker_dataloader_release_pool)' -m 'serial' + + - name: Test Array API + env: + MXNET_ENFORCE_CYTHON: 0 + run: | + cd .. + git clone https://github.com/data-apis/array-api-tests.git + cd array-api-tests + git checkout c1dba80a196a03f880d2e0a998a272fb3867b720 + export ARRAY_API_TESTS_MODULE=mxnet.numpy + export DMLC_LOG_STACK_TRACE_DEPTH=100 + python3 -m pytest --reruns 3 --durations=50 --verbose array_api_tests/test_creation_functions.py + python3 -m pytest --reruns 3 --durations=50 --verbose array_api_tests/test_indexing.py + python3 -m pytest --reruns 3 --durations=50 --verbose array_api_tests/test_constants.py + python3 -m pytest --reruns 3 --durations=50 --verbose array_api_tests/test_elementwise_functions.py + python3 -m pytest --reruns 3 --durations=50 --verbose array_api_tests/test_broadcasting.py + python3 -m pytest --reruns 3 --durations=50 --verbose \ + array_api_tests/test_type_promotion.py::test_elementwise_function_two_arg_bool_type_promotion + python3 -m pytest --reruns 3 --durations=50 --verbose \ + array_api_tests/test_type_promotion.py::test_elementwise_function_two_arg_promoted_type_promotion + python3 -m pytest --reruns 3 --durations=50 --verbose \ + array_api_tests/test_type_promotion.py::test_elementwise_function_one_arg_bool + python3 -m pytest --reruns 3 --durations=50 --verbose \ + array_api_tests/test_type_promotion.py::test_elementwise_function_one_arg_type_promotion + python3 -m pytest --reruns 3 --durations=50 --verbose \ + array_api_tests/test_type_promotion.py::test_operator_one_arg_type_promotion + python3 -m pytest --reruns 3 --durations=50 --verbose \ + array_api_tests/test_type_promotion.py::test_operator_two_arg_bool_promotion + python3 -m pytest --reruns 3 --durations=50 --verbose \ + array_api_tests/test_type_promotion.py::test_operator_two_arg_promoted_promotion + python3 -m pytest --reruns 3 --durations=50 --verbose \ + array_api_tests/test_type_promotion.py::test_operator_inplace_two_arg_promoted_promotion diff --git a/ci/docker/install/requirements b/ci/docker/install/requirements index 21f10b92cba8..7b8e2d033591 100644 --- a/ci/docker/install/requirements +++ b/ci/docker/install/requirements @@ -41,6 +41,7 @@ pytest-env==0.6.2 pytest-cov==2.10.1 pytest-xdist==2.1.0
pytest-timeout==1.4.2 +pytest-rerunfailures==10.2 flaky==3.7.0 setuptools==49.6.0 # https://github.com/pypa/setuptools/issues/2352 wheel diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index d89994d972bc..8ffb49d24141 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -872,10 +872,27 @@ unittest_array_api_standardization() { # when cython is enabled export MXNET_ENABLE_CYTHON=0 export DMLC_LOG_STACK_TRACE_DEPTH=100 - python3 -m pytest --durations=50 --cov-report xml:tests_api.xml --verbose \ + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose array_api_tests/test_creation_functions.py + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose array_api_tests/test_indexing.py + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose array_api_tests/test_elementwise_functions.py + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose array_api_tests/test_constants.py + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose array_api_tests/test_broadcasting.py + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose \ array_api_tests/test_type_promotion.py::test_elementwise_function_two_arg_bool_type_promotion - python3 -m pytest --durations=50 --cov-report xml:tests_api.xml --verbose array_api_tests/test_creation_functions.py - python3 -m pytest --durations=50 --cov-report xml:tests_api.xml --verbose array_api_tests/test_indexing.py + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose \ + array_api_tests/test_type_promotion.py::test_elementwise_function_two_arg_promoted_type_promotion + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose \ + array_api_tests/test_type_promotion.py::test_elementwise_function_one_arg_bool + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose \ + array_api_tests/test_type_promotion.py::test_elementwise_function_one_arg_type_promotion + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose \ + array_api_tests/test_type_promotion.py::test_operator_one_arg_type_promotion + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose \ + array_api_tests/test_type_promotion.py::test_operator_two_arg_bool_promotion + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose \ + array_api_tests/test_type_promotion.py::test_operator_two_arg_promoted_promotion + python3 -m pytest --reruns 3 --durations=50 --cov-report xml:tests_api.xml --verbose \ + array_api_tests/test_type_promotion.py::test_operator_inplace_two_arg_promoted_promotion popd } diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index 69c6a88643ab..e6f40806e273 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -675,7 +675,7 @@ def test_unix_python3_array_api(lib_name) { return ['Python3: Array-API': { node(NODE_LINUX_CPU) { ws('workspace/ut-python3-cpu') { - utils.unpack_and_init(lib_name, mx_lib, true) + utils.unpack_and_init(lib_name, mx_lib, false) python3_ut_array_api('ubuntu_cpu') utils.publish_test_coverage() } diff --git a/ci/jenkins/Jenkinsfile_unix_cpu b/ci/jenkins/Jenkinsfile_unix_cpu index 9681270d8905..22fc536592c2 100644 --- a/ci/jenkins/Jenkinsfile_unix_cpu +++ b/ci/jenkins/Jenkinsfile_unix_cpu @@ -46,7 +46,8 
@@ core_logic: { utils.parallel_stage('Tests', [ custom_steps.test_unix_python3_cpu('cpu'), custom_steps.test_unix_python3_onnx_cpu('cpu'), - custom_steps.test_unix_python3_array_api('cpu'), + // TVMOP has issue with NAN, see https://github.com/apache/incubator-mxnet/issues/20729 + custom_steps.test_unix_python3_array_api('cpu_openblas_no_tvm_op'), custom_steps.test_unix_python3_mkl_cpu('cpu_mkl'), custom_steps.test_unix_python3_onednn_cpu('onednn_cpu'), custom_steps.test_unix_python3_onednn_mkl_cpu('onednn_mkl_cpu'), diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 6da2c0641153..e0b20abfcf76 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -3757,6 +3757,8 @@ def ceil(x, out=None, **kwargs): >>> a array(4.) """ + if isinstance(x, NDArray) and _np.issubdtype(x.dtype, _np.integer): + return x return _pure_unary_func_helper(x, _api_internal.ceil, _np.ceil, out=out, **kwargs) @@ -3796,6 +3798,8 @@ def floor(x, out=None, **kwargs): >>> a array(3.) """ + if isinstance(x, NDArray) and _np.issubdtype(x.dtype, _np.integer): + return x return _pure_unary_func_helper(x, _api_internal.floor, _np.floor, out=out, **kwargs) @@ -3941,6 +3945,8 @@ def trunc(x, out=None, **kwargs): >>> np.trunc(a) array([-1., -1., -0., 0., 1., 1., 2.]) """ + if isinstance(x, NDArray) and _np.issubdtype(x.dtype, _np.integer): + return x return _pure_unary_func_helper(x, _api_internal.trunc, _np.trunc, out=out, **kwargs) diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py index 45699f714ed4..1228dac666e8 100644 --- a/python/mxnet/numpy/__init__.py +++ b/python/mxnet/numpy/__init__.py @@ -28,6 +28,7 @@ from .function_base import * # pylint: disable=wildcard-import from .stride_tricks import * # pylint: disable=wildcard-import from .set_functions import * # pylint: disable=wildcard-import +from .type_functions import * # pylint: disable=wildcard-import from .io import * # pylint: disable=wildcard-import from .arrayprint import * # pylint: disable=wildcard-import diff --git a/python/mxnet/numpy/fallback.py b/python/mxnet/numpy/fallback.py index 83bf67372517..c8fc7fbaf7f8 100644 --- a/python/mxnet/numpy/fallback.py +++ b/python/mxnet/numpy/fallback.py @@ -94,7 +94,6 @@ 'pv', 'rate', 'real', - 'result_type', 'roots', 'searchsorted', 'select', diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index d019cb73c30e..5a2ac27f7e4c 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -584,6 +584,8 @@ def _get_np_boolean_indexing(self, key, ndim, shape): remaining_dims = shape[key_ndim:] data = _reshape_view(self, -1, *remaining_dims) key = _reshape_view(key, -1) + if data.size == 0 and key.size == 0: + return data return _reshape_view(_npi.boolean_mask(data, key), -1, *remaining_dims) def _set_np_boolean_indexing(self, key, value): @@ -13285,34 +13287,50 @@ def asarray(obj, dtype=None, device=None, copy=None): Examples -------- - >>> a = np.arange(4).reshape(2,2) - >>> a - array([[0, 1], - [2, 3]]) - >>> np.diagonal(a) - array([0, 3]) - >>> np.diagonal(a, 1) - array([1]) + >>> np.asarray([1, 2, 3]) + array([1., 2., 3.]) - >>> a = np.arange(8).reshape(2,2,2) - >>>a - array([[[0, 1], - [2, 3]], - [[4, 5], - [6, 7]]]) - >>> np.diagonal(a, 0, 0, 1) - array([[0, 6], - [1, 7]]) + >>> np.asarray([[1, 2], [3, 4]], dtype=np.int32) + array([[1, 2], + [3, 4]], dtype=int32) + + >>> np.asarray([1.2], device=mx.gpu()) + array([1.2], device=gpu(0)) """ if isinstance(obj, 
numeric_types): dtype = dtype_from_number(obj) if dtype is None else dtype obj = _np.asarray(obj, dtype=dtype) elif isinstance(obj, _np.ndarray): - dtype = obj.dtype if dtype is None else dtype + if is_np_default_dtype(): + dtype = obj.dtype if dtype is None else dtype + else: + dtype = _np.float32 if dtype is None or obj.dtype is _np.float64 else dtype elif isinstance(obj, ndarray): - dtype = obj.dtype if dtype is None else dtype - array = _as_mx_np_array(obj, device=device, zero_copy=copy) - return array.astype(dtype) + if dtype is not None: + obj = obj.astype(dtype, copy=copy) + if device is not None: + obj = obj.to_device(device) + return obj + elif hasattr(obj, '__dlpack__'): + return from_dlpack(obj) + else: + if dtype is None: + default_dtype = _np.float64 if is_np_default_dtype() else _np.float32 + dtype = obj.dtype if hasattr(obj, "dtype") else default_dtype + try: + obj = _np.array(obj, dtype=dtype) + except Exception as e: + # printing out the error raised by official NumPy's array function + # for transparency on users' side + raise TypeError('{}'.format(str(e))) + if device is None: + device = current_device() + ret = empty(obj.shape, dtype=dtype, device=device) + if len(obj.shape) == 0: + ret[()] = obj + else: + ret[:] = obj + return ret # pylint: disable=redefined-outer-name diff --git a/python/mxnet/numpy/type_functions.py b/python/mxnet/numpy/type_functions.py new file mode 100644 index 000000000000..bf95f1cc8ef7 --- /dev/null +++ b/python/mxnet/numpy/type_functions.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Type functions for the numpy module.""" + +from typing import NamedTuple + +import numpy as onp +from .multiarray import ndarray +from .utils import _type_promotion_table + + +__all__ = ['can_cast', 'finfo', 'iinfo', 'result_type'] + +class finfo_obj(NamedTuple): + bits: int + eps: float + max: float + min: float + smallest_normal: float + + +class iinfo_obj(NamedTuple): + bits: int + max: int + min: int + + +def can_cast(from_, to): + """ + Returns True if cast between data types can occur according to + the casting rule. If from is a scalar or array scalar, + also returns True if the scalar value can be cast without + overflow or truncation to an integer. + Parameters + ---------- + from_ : dtype, ndarray or scalar + Data type, scalar, or array to cast from. + to : dtype + Data type to cast to. + Returns + ------- + out : bool + True if cast can occur according to the casting rule. + """ + if isinstance(from_, ndarray): + from_ = from_.asnumpy() + return onp.can_cast(from_, to) + + +def finfo(dtype): + """ + Machine limits for floating-point data types. 
+ Notes + ----- + `finfo` is a standard API in + https://data-apis.org/array-api/latest/API_specification/data_type_functions.html#finfo-type + instead of an official NumPy operator. + Parameters + ---------- + dtype : ndarray, float or dtype + Kind of floating point data-type about which to get information. + Returns + ------- + out : finfo object + an object having the following attributes: + - bits : int + number of bits occupied by the floating-point data type. + - eps : float + difference between 1.0 and the next smallest representable floating-point + number larger than 1.0 according to the IEEE-754 standard. + - max : float + largest representable number. + - min : float + smallest representable number. + - smallest_normal : float + smallest positive floating-point number with full precision. + """ + f_info = onp.finfo(dtype) + return finfo_obj(f_info.bits, float(f_info.eps), + float(f_info.max), float(f_info.min), float(f_info.tiny)) + + +def iinfo(dtype): + """ + Machine limits for integer data types. + Notes + ----- + `iinfo` is a standard API in + https://data-apis.org/array-api/latest/API_specification/data_type_functions.html#iinfo-type + instead of an official NumPy operator. + Parameters + ---------- + dtype : ndarray, integer or dtype + The kind of integer data type to get information about. + Returns + ------- + out : iinfo object + an object having the following attributes: + - bits : int + number of bits occupied by the type + - max : int + largest representable number. + - min : int + smallest representable number. + """ + i_info = onp.iinfo(dtype) + return iinfo_obj(i_info.bits, i_info.max, i_info.min) + + +def _get_dtype(array_or_dtype): + """Utility function for result_type""" + if isinstance(array_or_dtype, (ndarray, onp.ndarray)): + return array_or_dtype.dtype + elif isinstance(array_or_dtype, onp.dtype): + return array_or_dtype + else: + raise ValueError("Inputs of result_type must be ndarrays or dtypes") + + +def result_type(*arrays_and_dtypes): + """ + Returns the dtype that results from applying the type promotion rules to the arguments. + Notes + ----- + `result_type` is a standard API in + https://data-apis.org/array-api/latest/API_specification/data_type_functions.html#result-type-arrays-and-dtypes + instead of an official NumPy operator. + Parameters + ---------- + arrays_and_dtypes : mixed ndarrays and dtypes + an arbitrary number of input arrays and/or dtypes. + Returns + ------- + out : dtype + the dtype resulting from an operation involving the input arrays and dtypes.
+ """ + if len(arrays_and_dtypes) > 0: + ret = _get_dtype(arrays_and_dtypes[0]) + for d in arrays_and_dtypes[1:]: + dd = _get_dtype(d) + if (ret, dd) in _type_promotion_table: + ret = _type_promotion_table[ret, dd] + elif (dd, ret) in _type_promotion_table: + ret = _type_promotion_table[dd, ret] + else: + raise TypeError("Unknown type promotion between {} and {}".format(ret, dd)) + return ret + raise ValueError("at least one array or dtype is required") diff --git a/python/mxnet/numpy/utils.py b/python/mxnet/numpy/utils.py index 15b83c7f2b73..21fe1e299d2e 100644 --- a/python/mxnet/numpy/utils.py +++ b/python/mxnet/numpy/utils.py @@ -23,25 +23,26 @@ __all__ = ['float16', 'float32', 'float64', 'uint8', 'int32', 'int8', 'int64', 'int16', 'uint16', 'uint32', 'uint64', - 'bool', 'bool_', 'pi', 'inf', 'nan', 'PZERO', 'NZERO', 'newaxis', 'finfo', + 'bool', 'bool_', 'pi', 'inf', 'nan', 'PZERO', 'NZERO', 'newaxis', 'e', 'NINF', 'PINF', 'NAN', 'NaN', - '_STR_2_DTYPE_', '_DTYPE_2_STR_'] + '_STR_2_DTYPE_', '_DTYPE_2_STR_', '_type_promotion_table', + 'integer_dtypes', 'floating_dtypes', 'boolean_dtypes', 'numeric_dtypes'] py_bool = bool -float16 = onp.float16 -float32 = onp.float32 -float64 = onp.float64 -uint8 = onp.uint8 -int32 = onp.int32 -int8 = onp.int8 -int64 = onp.int64 -bool_ = onp.bool_ -bool = onp.bool -int16 = onp.int16 -uint16 = onp.uint16 -uint32 = onp.uint32 -uint64 = onp.uint64 +float16 = onp.dtype(onp.float16) +float32 = onp.dtype(onp.float32) +float64 = onp.dtype(onp.float64) +uint8 = onp.dtype(onp.uint8) +int32 = onp.dtype(onp.int32) +int8 = onp.dtype(onp.int8) +int64 = onp.dtype(onp.int64) +bool_ = onp.dtype(onp.bool_) +bool = onp.dtype(onp.bool) +int16 = onp.dtype(onp.int16) +uint16 = onp.dtype(onp.uint16) +uint32 = onp.dtype(onp.uint32) +uint64 = onp.dtype(onp.uint64) pi = onp.pi inf = onp.inf @@ -55,7 +56,6 @@ NaN = onp.NaN newaxis = None -finfo = onp.finfo _STR_2_DTYPE_ = {'float16': float16, 'float32': float32, 'float64': float64, 'float': float64, 'int8': int8, 'int16': int16, 'int32': int32, 'int64': int64, 'int': int64, @@ -77,3 +77,125 @@ def _get_np_op(name): if op is not None: return op raise ValueError('Operator `{}` is not supported by `mxnet.numpy`.'.format(name)) + + +_type_promotion_table = { + # signed integer type promotion + (int8, int8): int8, + (int8, int16): int16, + (int8, int32): int32, + (int8, int64): int64, + (int16, int16): int16, + (int16, int32): int32, + (int16, int64): int64, + (int32, int32): int32, + (int32, int64): int64, + (int64, int64): int64, + # unsigned integer type promotion + (uint8, uint8): uint8, + (uint8, uint16): uint16, + (uint8, uint32): uint32, + (uint8, uint64): uint64, + (uint16, uint16): uint16, + (uint16, uint32): uint32, + (uint16, uint64): uint64, + (uint32, uint32): uint32, + (uint32, uint64): uint64, + (uint64, uint64): uint64, + # mixed signed and unsigned integer type promotion + (int8, uint8): int16, + (int8, uint16): int32, + (int8, uint32): int64, + (int16, uint8): int16, + (int16, uint16): int32, + (int16, uint32): int64, + (int32, uint8): int32, + (int32, uint16): int32, + (int32, uint32): int64, + (int64, uint8): int64, + (int64, uint16): int64, + (int64, uint32): int64, + # float type promotion + (float16, float16): float16, + (float16, float32): float32, + (float16, float64): float64, + (float32, float32): float32, + (float32, float64): float64, + (float64, float64): float64, + # bool type promotion + (bool, bool): bool, + # mixed integer and float16 type promotion + (int8, float16): float16, + (int16, float16): 
float16, + (int32, float16): float16, + (int64, float16): float16, + (uint8, float16): float16, + (uint16, float16): float16, + (uint32, float16): float16, + (uint64, float16): float16, + # mixed integer and float16 type promotion + (int8, float32): float32, + (int16, float32): float32, + (int32, float32): float32, + (int64, float32): float32, + (uint8, float32): float32, + (uint16, float32): float32, + (uint32, float32): float32, + (uint64, float32): float32, + # mixed integer and float32 type promotion + (int8, float32): float32, + (int16, float32): float32, + (int32, float32): float32, + (int64, float32): float32, + (uint8, float32): float32, + (uint16, float32): float32, + (uint32, float32): float32, + (uint64, float32): float32, + # mixed integer and float64 type promotion + (int8, float64): float64, + (int16, float64): float64, + (int32, float64): float64, + (int64, float64): float64, + (uint8, float64): float64, + (uint16, float64): float64, + (uint32, float64): float64, + (uint64, float64): float64, + # mixed bool and other type promotion + (bool, int8): int8, + (bool, int16): int16, + (bool, int32): int32, + (bool, int64): int64, + (bool, uint8): uint8, + (bool, uint16): uint16, + (bool, uint32): uint32, + (bool, uint64): uint64, + (bool, float16): float16, + (bool, float32): float32, + (bool, float64): float64, +} + +integer_dtypes = [ + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +] + +floating_dtypes = [ + float16, + float32, + float64, +] + +numeric_dtypes = [ + *integer_dtypes, + *floating_dtypes, +] + +boolean_dtypes = [ + bool_, +] diff --git a/python/mxnet/util.py b/python/mxnet/util.py index cf2c2a95e628..f99dfd07413e 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -23,7 +23,7 @@ from struct import calcsize from .base import (_LIB, check_call, c_str, py_str, - numeric_types, integer_types, + numeric_types, integer_types, long, _MAX_VALUE_64_BIT_UNSIGNED_, _MAX_VALUE_64_BIT_SIGNED_, _MAX_VALUE_FLOAT32_REPRESENT_) @@ -1339,7 +1339,7 @@ def dtype_from_number(number): assert isinstance(number, numeric_types),\ "The input number should be either int for float types" import numpy as _np - if isinstance(number, integer_types): + if isinstance(number, (int, long)): if number > _MAX_VALUE_64_BIT_UNSIGNED_: raise OverflowError("Integer out of bounds") if number > _MAX_VALUE_64_BIT_SIGNED_: @@ -1348,8 +1348,14 @@ def dtype_from_number(number): return _np.int64 else: return _np.int32 - else: - if abs(number) > _MAX_VALUE_FLOAT32_REPRESENT_: + elif isinstance(number, float): + if abs(number) > _MAX_VALUE_FLOAT32_REPRESENT_ or \ + ((not _np.isnan(number)) and \ + (_np.float32(number) == int(number)) and \ + (number != int(number))): return _np.float64 else: return _np.float64 if is_np_default_dtype() else _np.float32 + elif isinstance(number, _np.generic): + return number.dtype + raise TypeError('type {} not supported'.format(str(type(number)))) diff --git a/src/common/cuda/rtc/util-inl.h b/src/common/cuda/rtc/util-inl.h index f294aa0ef2eb..66e23518b865 100644 --- a/src/common/cuda/rtc/util-inl.h +++ b/src/common/cuda/rtc/util-inl.h @@ -37,6 +37,10 @@ using uint8 = unsigned char; using int8 = char; using int32 = int; using int64 = long long; +using int16 = short; +using uint16 = unsigned short; +using uint32 = unsigned int; +using uint64 = unsigned long long; static_assert(sizeof(float32) == 4, "Size of float32 is expected to be 4B"); static_assert(sizeof(float64) == 8, "Size of float64 is expected to be 8B"); @@ -45,6 +49,10 @@ 
static_assert(sizeof(uint8) == 1, "Size of uint8 is expected to be 1B"); static_assert(sizeof(int8) == 1, "Size of int8 is expected to be 1B"); static_assert(sizeof(int32) == 4, "Size of int32 is expected to be 4B"); static_assert(sizeof(int64) == 8, "Size of int64 is expected to be 8B"); +static_assert(sizeof(int16) == 2, "Size of int16 is expected to be 2B"); +static_assert(sizeof(uint16) == 2, "Size of uint16 is expected to be 2B"); +static_assert(sizeof(uint32) == 4, "Size of uint32 is expected to be 4B"); +static_assert(sizeof(uint64) == 8, "Size of uint64 is expected to be 8B"); )code" #if MSHADOW_INT64_TENSOR_SIZE == 1 @@ -129,7 +137,11 @@ struct true_type { // is_integral template struct is_integral : false_type {}; template <> struct is_integral : true_type {}; +template <> struct is_integral : true_type {}; +template <> struct is_integral : true_type {}; +template <> struct is_integral : true_type {}; template <> struct is_integral : true_type {}; +template <> struct is_integral : true_type {}; template <> struct is_integral : true_type {}; template <> struct is_integral : true_type {}; template <> struct is_integral : true_type {}; @@ -138,6 +150,9 @@ template <> struct is_integral : true_type {}; // is_unsigned template struct is_unsigned : false_type {}; template <> struct is_unsigned : true_type {}; +template <> struct is_unsigned : true_type {}; +template <> struct is_unsigned : true_type {}; +template <> struct is_unsigned : true_type {}; template <> struct is_unsigned : true_type {}; template <> struct is_unsigned : true_type {}; @@ -211,19 +226,141 @@ struct mixed_type_helper::value>:: template struct mixed_type_helper::value && is_integral::value && + is_unsigned::value && + is_unsigned::value && !is_same::value && - sizeof(T) <= sizeof(U)>::type> { + sizeof(T) < sizeof(U)>::type> { + using type = U; +}; + +template +struct mixed_type_helper::value && + is_integral::value && + !is_unsigned::value && + !is_unsigned::value && + !is_same::value && + sizeof(T) < sizeof(U)>::type> { + using type = U; +}; + +template +struct mixed_type_helper::value && + is_integral::value && + is_unsigned::value && + !is_unsigned::value && + !is_same::value && + sizeof(T) < sizeof(U)>::type> { + using type = U; +}; + +template +struct mixed_type_helper::value && + is_integral::value && + is_unsigned::value && + is_unsigned::value && + !is_same::value && + sizeof(T) < sizeof(U)>::type> { using type = U; }; template struct mixed_type_helper::value && is_integral::value && + !is_unsigned::value && + !is_unsigned::value && !is_same::value && sizeof(T) < sizeof(U)>::type> { using type = U; }; +template +struct mixed_type_helper::value && + is_integral::value && + is_unsigned::value && + !is_unsigned::value && + !is_same::value && + sizeof(T) < sizeof(U)>::type> { + using type = U; +}; + +template +struct mixed_type_helper::value && + is_integral::value && + !is_same::value && + is_same::value>::type> { + using type = U; +}; + +template<> +struct mixed_type_helper { + using type = int16; +}; + +template<> +struct mixed_type_helper { + using type = int16; +}; + +template<> +struct mixed_type_helper { + using type = int32; +}; + +template<> +struct mixed_type_helper { + using type = int32; +}; + +template<> +struct mixed_type_helper { + using type = int64; +}; + +template<> +struct mixed_type_helper { + using type = int64; +}; + +template<> +struct mixed_type_helper { + using type = int32; +}; + +template<> +struct mixed_type_helper { + using type = int32; +}; + +template<> +struct mixed_type_helper 
{ + using type = int64; }; + +template<> +struct mixed_type_helper { + using type = int64; }; + +template<> +struct mixed_type_helper { + using type = int64; }; + +template<> +struct mixed_type_helper { + using type = int64; }; + +template<> +struct mixed_type_helper { + using type = index_t; }; + +template<> +struct mixed_type_helper { + using type = index_t; }; + template struct mixed_type_helper::value && sizeof(T) < sizeof(bool_t)>::type> { @@ -242,6 +379,13 @@ struct mixed_type_helper::value && using type = T; }; + +template +struct mixed_type_helper::value && + !is_same::value && + sizeof(T) == sizeof(bool_t)>::type> { + using type = T; +}; + template struct multi_mixed_type_helper; @@ -472,11 +616,31 @@ template<> __device__ inline uint8 MinValue(void) { return 0; } +/*! \brief minimum value of uint16 */ +template<> +__device__ inline uint16 MinValue(void) { + return 0; +} +/*! \brief minimum value of uint32 */ +template<> +__device__ inline uint32 MinValue(void) { + return 0; +} +/*! \brief minimum value of uint64 */ +template<> +__device__ inline uint64 MinValue(void) { + return 0; +} /*! \brief minimum value of int8_t */ template<> __device__ inline int8 MinValue(void) { return -128; } +/*! \brief minimum value of int16 */ +template<> +__device__ inline int16 MinValue(void) { + return -32768; +} /*! \brief minimum value of int32 */ template<> __device__ inline int32 MinValue(void) { @@ -538,11 +702,31 @@ template<> __device__ inline uint8 MaxValue(void) { return 255; } +/*! \brief maximum value of uint16 */ +template<> +__device__ inline uint16 MaxValue(void) { + return 65535; +} +/*! \brief maximum value of uint32 */ +template<> +__device__ inline uint32 MaxValue(void) { + return 4294967295; +} +/*! \brief maximum value of uint64 */ +template<> +__device__ inline uint64 MaxValue(void) { + return 18446744073709551615ULL; +} /*! \brief maximum value of int8 */ template<> __device__ inline int8 MaxValue(void) { return 127; } +/*! \brief maximum value of int16 */ +template<> +__device__ inline int16 MaxValue(void) { + return 32767; +} /*!
\brief maximum value of int32 */ template<> __device__ inline int32 MaxValue(void) { diff --git a/src/common/utils.cc b/src/common/utils.cc index f400093cc9b5..639ded4ec80e 100644 --- a/src/common/utils.cc +++ b/src/common/utils.cc @@ -117,6 +117,14 @@ MShadowTypeInfo mshadow_type_info(const int type_flag) { return MShadowTypeInfo("float16", 2, sizeof(float)); case kUint8: return MShadowTypeInfo("uint8", sizeof(uint8_t), sizeof(index_t)); + case kUint16: + return MShadowTypeInfo("uint16", sizeof(uint16_t)); + case kUint32: + return MShadowTypeInfo("uint32", sizeof(uint32_t)); + case kUint64: + return MShadowTypeInfo("uint64", sizeof(uint64_t)); + case kInt16: + return MShadowTypeInfo("int16", sizeof(int16_t)); case kInt32: return MShadowTypeInfo("int32", sizeof(int32_t)); case kInt8: diff --git a/src/common/utils.h b/src/common/utils.h index 15e676c816c9..7cf54ca2b4fa 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -901,11 +901,53 @@ inline bool is_float(const int dtype) { } inline bool is_int(const int dtype) { - return dtype == mshadow::kUint8 || dtype == mshadow::kInt8 || dtype == mshadow::kInt32 || + return dtype == mshadow::kUint8 || dtype == mshadow::kInt8 || dtype == mshadow::kUint16 || + dtype == mshadow::kInt16 || dtype == mshadow::kUint32 || dtype == mshadow::kInt32 || + dtype == mshadow::kUint64 || dtype == mshadow::kInt64; +} + +inline bool is_signed_int(const int dtype) { + return dtype == mshadow::kInt8 || dtype == mshadow::kInt16 || dtype == mshadow::kInt32 || dtype == mshadow::kInt64; } -inline int get_more_precise_type(const int type1, const int type2) { +inline bool is_unsigned_int(const int dtype) { + return dtype == mshadow::kUint8 || dtype == mshadow::kUint16 || dtype == mshadow::kUint32 || + dtype == mshadow::kUint64; +} + +static int bits_of(const int type_flag) { + switch (type_flag) { + case mshadow::kFloat32: + return sizeof(float) * CHAR_BIT; + case mshadow::kFloat64: + return sizeof(double) * CHAR_BIT; + case mshadow::kUint8: + return sizeof(uint8_t) * CHAR_BIT; + case mshadow::kInt32: + return sizeof(int32_t) * CHAR_BIT; + case mshadow::kInt8: + return sizeof(int8_t) * CHAR_BIT; + case mshadow::kInt64: + return sizeof(int64_t) * CHAR_BIT; + case mshadow::kBool: + return sizeof(bool) * CHAR_BIT; + case mshadow::kInt16: + return sizeof(int16_t) * CHAR_BIT; + case mshadow::kUint16: + return sizeof(uint16_t) * CHAR_BIT; + case mshadow::kUint32: + return sizeof(uint32_t) * CHAR_BIT; + case mshadow::kUint64: + return sizeof(uint64_t) * CHAR_BIT; + default: { + LOG(FATAL) << "Unknown type_flag=" << type_flag; + return -1; + } + } +} + +inline int type_promotion(const int type1, const int type2) { if (type1 == type2) return type1; if (is_float(type1) && is_float(type2)) { @@ -919,27 +961,74 @@ inline int get_more_precise_type(const int type1, const int type2) { } else if (is_float(type1) || is_float(type2)) { return is_float(type1) ? 
type1 : type2; } - if (type1 == mshadow::kInt64 || type2 == mshadow::kInt64) { - return mshadow::kInt64; - } - if (type1 == mshadow::kInt32 || type2 == mshadow::kInt32) { - return mshadow::kInt32; - } - CHECK(!((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) || - (type1 == mshadow::kInt8 && type2 == mshadow::kUint8))) - << "1 is UInt8 and 1 is Int8 should not get here"; - if (type1 == mshadow::kUint8 || type2 == mshadow::kUint8) { + if (is_signed_int(type1) && is_signed_int(type2)) { + if (type1 == mshadow::kInt64 || type2 == mshadow::kInt64) { + return mshadow::kInt64; + } + if (type1 == mshadow::kInt32 || type2 == mshadow::kInt32) { + return mshadow::kInt32; + } + if (type1 == mshadow::kInt16 || type2 == mshadow::kInt16) { + return mshadow::kInt16; + } + return mshadow::kInt8; + } else if (is_unsigned_int(type1) && is_unsigned_int(type2)) { + if (type1 == mshadow::kUint64 || type2 == mshadow::kUint64) { + return mshadow::kUint64; + } + if (type1 == mshadow::kUint32 || type2 == mshadow::kUint32) { + return mshadow::kUint32; + } + if (type1 == mshadow::kUint16 || type2 == mshadow::kUint16) { + return mshadow::kUint16; + } return mshadow::kUint8; + } else if (type1 == mshadow::kBool) { + return type2; + } else if (type2 == mshadow::kBool) { + return type1; + } else if (is_unsigned_int(type1) || is_unsigned_int(type2)) { + if (bits_of(type1) < bits_of(type2)) { + if (type1 == mshadow::kInt8 && type2 == mshadow::kUint16) { + return mshadow::kInt32; + } else if (type1 == mshadow::kInt8 && type2 == mshadow::kUint32) { + return mshadow::kInt64; + } else if (type1 == mshadow::kInt16 && type2 == mshadow::kUint32) { + return mshadow::kInt64; + } else if (type2 == mshadow::kUint64) { + LOG(FATAL) << "Unsupported type promotions between " << mshadow::dtype_string(type1) + << " and " << mshadow::dtype_string(type2); + } else { + return type2; + } + } else if (bits_of(type2) < bits_of(type1)) { + if (type2 == mshadow::kInt8 && type1 == mshadow::kUint16) { + return mshadow::kInt32; + } else if (type2 == mshadow::kInt8 && type1 == mshadow::kUint32) { + return mshadow::kInt64; + } else if (type2 == mshadow::kInt16 && type1 == mshadow::kUint32) { + return mshadow::kInt64; + } else if (type1 == mshadow::kUint64) { + LOG(FATAL) << "Unsupported type promotions between " << mshadow::dtype_string(type1) + << " and " << mshadow::dtype_string(type2); + } else { + return type1; + } + } else { + if (type1 == mshadow::kUint8 || type2 == mshadow::kUint8) { + return mshadow::kInt16; + } + if (type1 == mshadow::kUint16 || type2 == mshadow::kUint16) { + return mshadow::kInt32; + } + if (type1 == mshadow::kUint32 || type2 == mshadow::kUint32) { + return mshadow::kInt64; + } + } } - return mshadow::kInt8; -} - -inline int np_binary_out_infer_type(const int type1, const int type2) { - if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) || - (type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) { - return mshadow::kInt32; - } - return get_more_precise_type(type1, type2); + LOG(FATAL) << "Unsupported type promotions between " << mshadow::dtype_string(type1) << " and " + << mshadow::dtype_string(type2); + return -1; } inline const std::string NodeAttrsGetProfilerScope(const nnvm::NodeAttrs& attrs) { diff --git a/src/ndarray/ndarray_function-inl.h b/src/ndarray/ndarray_function-inl.h index c1d81191dbee..8101bf2a624f 100644 --- a/src/ndarray/ndarray_function-inl.h +++ b/src/ndarray/ndarray_function-inl.h @@ -402,7 +402,7 @@ void EvalRandom(const real_t& mu, template <> void Eval(const real_t& rhs, TBlob* ret, 
RunContext ctx) { mshadow::Stream* s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH_WITH_BOOL( + MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL( ret->type_flag_, DType, { ret->FlatTo2D(s) = DType(rhs); }); } diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu index f6189f939131..3313014ec908 100644 --- a/src/ndarray/ndarray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -46,7 +46,7 @@ void Copy(const TBlob& from, RunContext ctx) { CHECK_EQ(to->type_flag_, from.type_flag_) << "Source and target must have the same data type when copying across devices."; - MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, { + MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(to->type_flag_, DType, { mshadow::Copy(to->FlatTo1D(), from.FlatTo1D(), ctx.get_stream()); }); } @@ -59,7 +59,7 @@ void Copy(const TBlob& from, RunContext ctx) { CHECK_EQ(to->type_flag_, from.type_flag_) << "Source and target must have the same data type when copying across devices."; - MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, { + MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(to->type_flag_, DType, { mshadow::Copy(to->FlatTo1D(), from.FlatTo1D(), ctx.get_stream()); }); } diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc index f3ba7f9f638f..b9307ea7d1dd 100644 --- a/src/operator/contrib/boolean_mask.cc +++ b/src/operator/contrib/boolean_mask.cc @@ -133,7 +133,7 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, const_cast(out).Init(s); // do the copy - MSHADOW_TYPE_SWITCH_WITH_BOOL(data.dtype(), DType, { + MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(data.dtype(), DType, { size_t input_size = data.shape().Size(); size_t col_size = input_size / idx_size; mshadow::Stream* stream = ctx.get_stream(); diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 9a14794a47da..41f1aa5d1828 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -239,12 +239,7 @@ struct floor_divide : public mxnet_op::tunable { typename std::enable_if::value && std::is_integral::value, int>::type = 0> MSHADOW_XINLINE static DType Map(DType a, DType b) { - DType c = static_cast(::floor(a / b)); - if ((c * a != b) && ((a < 0) != (b < 0))) { - return DType(c - 1); - } else { - return c; - } + return static_cast(::floor(static_cast(a) / static_cast(b))); } MSHADOW_XINLINE static bool Map(bool a, bool b) { @@ -270,12 +265,7 @@ struct rfloor_divide : public mxnet_op::tunable { typename std::enable_if::value && std::is_integral::value, int>::type = 0> MSHADOW_XINLINE static DType Map(DType a, DType b) { - DType c = static_cast(::floor(b / a)); - if ((c * a != b) && ((a < 0) != (b < 0))) { - return DType(c - 1); - } else { - return c; - } + return static_cast(::floor(static_cast(b) / static_cast(a))); } MSHADOW_XINLINE static bool Map(bool a, bool b) { @@ -819,7 +809,15 @@ MXNET_BINARY_MATH_OP(bitwise_or, static_cast(a) | static_cast( #endif /*! \brief used for generate element of bitwise_left_shift */ -MXNET_BINARY_MATH_OP(bitwise_left_shift, static_cast(a) << static_cast(b)); +struct bitwise_left_shift : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (static_cast(b) >= (sizeof(DType) * CHAR_BIT)) { + return DType(0); + } + return static_cast(a) << static_cast(b); + } +}; MXNET_BINARY_MATH_OP(bitwise_left_shift_grad, math::pow(2.0f, static_cast(b))); @@ -834,7 +832,19 @@ MXNET_BINARY_MATH_OP(rbitwise_left_shift_grad, math::log(2.0f)); /*! 
\brief used for generate element of bitwise_right_shift */ -MXNET_BINARY_MATH_OP(bitwise_right_shift, static_cast(a) >> static_cast(b)); +struct bitwise_right_shift : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (static_cast(b) >= (sizeof(DType) * CHAR_BIT)) { + if (a < 0) { + return DType(-1); + } else { + return DType(0); + } + } + return static_cast(a) >> static_cast(b); + } +}; MXNET_BINARY_MATH_OP(bitwise_right_shift_grad, math::pow(0.5f, static_cast(b))); @@ -995,10 +1005,16 @@ struct mod : public mxnet_op::tunable { } else if (b < DType(0)) { if (a < DType(0)) { return DType(-::fmod(-static_cast(a), -static_cast(b))); + } else if (a == DType(0)) { + return -DType(0); } else { - return DType( + DType ret = DType( ::fmod(static_cast(a), -static_cast(b)) + (::fmod(static_cast(a), -static_cast(b)) != DType(0) ? b : DType(0))); + if (ret == 0) { + return -ret; + } + return ret; } } else { if (a < DType(0)) { diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 09e42481a66b..c8a00fbfeefa 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -419,6 +419,60 @@ struct AccType { LOG(FATAL) << "Unknown type enum " << type; \ } + +#define MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(type, DType, ...) \ + switch (type) { \ + case mshadow::kFloat32: { \ + LOG(FATAL) << "This operation only support " \ + "integer and bool types, not float32"; \ + } break; \ + case mshadow::kFloat64: { \ + LOG(FATAL) << "This operation only support " \ + "integer and bool types, not float64"; \ + } break; \ + case mshadow::kFloat16: { \ + LOG(FATAL) << "This operation only support " \ + "integer and bool types, not float16"; \ + } break; \ + case mshadow::kUint8: { \ + typedef uint8_t DType; \ + { __VA_ARGS__ } \ + } break; \ + case mshadow::kInt8: { \ + typedef int8_t DType; \ + { __VA_ARGS__ } \ + } break; \ + case mshadow::kInt32: { \ + typedef int32_t DType; \ + { __VA_ARGS__ } \ + } break; \ + case mshadow::kInt64: { \ + typedef int64_t DType; \ + { __VA_ARGS__ } \ + } break; \ + case mshadow::kInt16: { \ + typedef int16_t DType; \ + { __VA_ARGS__ } \ + } break; \ + case mshadow::kUint16: { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } break; \ + case mshadow::kUint32: { \ + typedef uint32_t DType; \ + { __VA_ARGS__ } \ + } break; \ + case mshadow::kUint64: { \ + typedef uint64_t DType; \ + { __VA_ARGS__ } \ + } break; \ + case mshadow::kBool: { \ + typedef bool DType; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + #define MXNET_INT_TYPE_SWITCH_EXT(type, DType, ...)
\ switch (type) { \ case mshadow::kFloat32: { \ @@ -466,8 +520,8 @@ struct AccType { { __VA_ARGS__ } \ } break; \ case mshadow::kBool: { \ - typedef bool DType; \ - { __VA_ARGS__ } \ + LOG(FATAL) << "This operation only support " \ + "integer types, not bool type"; \ } break; \ default: \ LOG(FATAL) << "Unknown type enum " << type; \ diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index da40fe4044e7..29fe12150ff0 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -116,7 +116,7 @@ void MixedIntRealBinaryElemwiseCompute(const OpContext& ctx, if (size == 0) return; - MXNET_INT_TYPE_SWITCH(rhs.type_flag_, IType, { + MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(rhs.type_flag_, IType, { MXNET_ASSIGN_REQ_SWITCH(req, Req, { Kernel, xpu>::Launch( s, size, out.dptr(), rhs.dptr(), lhs.dptr()); @@ -125,7 +125,88 @@ void MixedIntRealBinaryElemwiseCompute(const OpContext& ctx, }); } -template +template +void MixedIntBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const TBlob& lhs, + const TBlob& rhs, + const TBlob& out, + const OpReqType req) { + using namespace mshadow; + using namespace mxnet_op; + + Stream* s = ctx.get_stream(); + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH_EXT(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + MXNET_INT_TYPE_SWITCH_EXT(out.type_flag_, DType, { + const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), temp_tblob.Size()) + + DataType::kLanes - 1) / + DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch( + s, size, out.dptr(), lhs.dptr(), temp_tblob.dptr()); + } + }); + }); + } else if (rhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH_EXT(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + MXNET_INT_TYPE_SWITCH_EXT(out.type_flag_, DType, { + const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), temp_tblob.Size(), rhs.Size()) + + DataType::kLanes - 1) / + DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch( + s, size, out.dptr(), temp_tblob.dptr(), rhs.dptr()); + } + }); + }); + } else { + TBlob temp_tblob_l; + TBlob temp_tblob_r; + MXNET_INT_TYPE_SWITCH_EXT(out.type_flag_, OType, { + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(lhs.Size() + rhs.Size()), s); + TBlob temp_tblob = TBlob(workspace); + temp_tblob_l = TBlob(reinterpret_cast(temp_tblob.dptr_), + lhs.shape_, + temp_tblob.dev_mask(), + temp_tblob.dev_id()); + temp_tblob_r = TBlob(reinterpret_cast(temp_tblob.dptr_) + lhs.Size() + 1, + rhs.shape_, + temp_tblob.dev_mask(), + temp_tblob.dev_id()); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob_l}); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob_r}); + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + MXNET_INT_TYPE_SWITCH_EXT(out.type_flag_, DType, { + const size_t size = + (ElemwiseBinaryOp::minthree(out.Size(), temp_tblob_l.Size(), temp_tblob_r.Size()) + + DataType::kLanes - 1) / + DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch( + s, size, out.dptr(), temp_tblob_l.dptr(), temp_tblob_r.dptr()); + } + }); + }); 
+ } +} + +template void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -152,7 +233,7 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs, MixedIntRealBinaryElemwiseCompute(ctx, rhs, lhs, out, req[0]); } } else { - PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); + MixedIntBinaryElemwiseCompute(attrs, ctx, lhs, rhs, out, req[0]); } } @@ -250,7 +331,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, int ndim = BinaryBroadcastShapeCompact( lhs.shape_, rhs.shape_, out.shape_, &new_lshape, &new_rshape, &new_oshape); if (!ndim) { - MixedBinaryElemwiseCompute(attrs, ctx, inputs, req, outputs); + MixedBinaryElemwiseCompute(attrs, ctx, inputs, req, outputs); } else { mshadow::Stream* s = ctx.get_stream(); if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { @@ -270,7 +351,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); if (lhs.type_flag_ == out.type_flag_) { MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, LType, { - MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(rhs.type_flag_, RType, { mxnet_op::Kernel, xpu>::template LaunchEx(s, new_oshape.Size(), @@ -285,7 +366,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, }); } else { MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, RType, { - MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(lhs.type_flag_, LType, { mxnet_op::Kernel, xpu>::template LaunchEx(s, new_oshape.Size(), @@ -303,7 +384,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, } else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) { TBlob temp_tblob; if (lhs.type_flag_ == out.type_flag_) { - MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(lhs.type_flag_, LType, { Tensor temp_tensor = ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); temp_tblob = TBlob(temp_tensor); @@ -311,8 +392,8 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); BinaryBroadcastCompute( attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); - } else { - MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + } else if (rhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(rhs.type_flag_, RType, { Tensor temp_tensor = ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); temp_tblob = TBlob(temp_tensor); @@ -320,6 +401,25 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); BinaryBroadcastCompute( attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } else { + TBlob temp_tblob_l; + TBlob temp_tblob_r; + MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(out.type_flag_, OType, { + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(lhs.Size() + rhs.Size()), s); + TBlob temp_tblob = TBlob(workspace); + temp_tblob_l = TBlob(reinterpret_cast(temp_tblob.dptr_), + lhs.shape_, + temp_tblob.dev_mask(), + temp_tblob.dev_id()); + temp_tblob_r = TBlob(reinterpret_cast(temp_tblob.dptr_) + lhs.Size() + 1, + rhs.shape_, + temp_tblob.dev_mask(), + temp_tblob.dev_id()); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob_l}); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob_r}); + BinaryBroadcastCompute(attrs, ctx, {temp_tblob_l, temp_tblob_r}, req, outputs); } } else 
{ PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); @@ -379,7 +479,7 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, Stream* s = ctx.get_stream(); TBlob temp_tblob; if (lhs.type_flag_ == out.type_flag_) { - MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(lhs.type_flag_, LType, { Tensor temp_tensor = ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); temp_tblob = TBlob(temp_tensor); @@ -387,8 +487,8 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); BinaryBroadcastCompute( attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); - } else { - MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + } else if (rhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(rhs.type_flag_, RType, { Tensor temp_tensor = ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); temp_tblob = TBlob(temp_tensor); @@ -396,12 +496,230 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); BinaryBroadcastCompute( attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } else { + TBlob temp_tblob_l; + TBlob temp_tblob_r; + MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(out.type_flag_, OType, { + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(lhs.Size() + rhs.Size()), s); + TBlob temp_tblob = TBlob(workspace); + temp_tblob_l = TBlob(reinterpret_cast(temp_tblob.dptr_), + lhs.shape_, + temp_tblob.dev_mask(), + temp_tblob.dev_id()); + temp_tblob_r = TBlob(reinterpret_cast(temp_tblob.dptr_) + lhs.Size() + 1, + rhs.shape_, + temp_tblob.dev_mask(), + temp_tblob.dev_id()); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob_l}); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob_r}); + BinaryBroadcastCompute(attrs, ctx, {temp_tblob_l, temp_tblob_r}, req, outputs); } return; } MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); } +template +void NumpyBinaryBroadcastIntComputeWithBool(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + + const TBlob& lhs = inputs[0]; + const TBlob& rhs = inputs[1]; + const TBlob& out = outputs[0]; + + if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) + return; + + if (lhs.type_flag_ == rhs.type_flag_) { + BinaryBroadcastIntComputeWithBool(attrs, ctx, inputs, req, outputs); + return; + } + Stream* s = ctx.get_stream(); + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastIntComputeWithBool( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else if (rhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastIntComputeWithBool( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } else { + TBlob temp_tblob_l; + TBlob temp_tblob_r; + 
MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(out.type_flag_, OType, { + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(lhs.Size() + rhs.Size()), s); + TBlob temp_tblob = TBlob(workspace); + temp_tblob_l = TBlob(reinterpret_cast(temp_tblob.dptr_), + lhs.shape_, + temp_tblob.dev_mask(), + temp_tblob.dev_id()); + temp_tblob_r = TBlob(reinterpret_cast(temp_tblob.dptr_) + lhs.Size() + 1, + rhs.shape_, + temp_tblob.dev_mask(), + temp_tblob.dev_id()); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob_l}); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob_r}); + BinaryBroadcastIntComputeWithBool( + attrs, ctx, {temp_tblob_l, temp_tblob_r}, req, outputs); + } + return; +} + +template +void NumpyBinaryBroadcastIntCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + + const TBlob& lhs = inputs[0]; + const TBlob& rhs = inputs[1]; + const TBlob& out = outputs[0]; + + if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) + return; + + if (lhs.type_flag_ == rhs.type_flag_) { + BinaryBroadcastIntCompute(attrs, ctx, inputs, req, outputs); + return; + } + Stream* s = ctx.get_stream(); + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH_EXT(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastIntCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else if (rhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH_EXT(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastIntCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } else { + TBlob temp_tblob_l; + TBlob temp_tblob_r; + MXNET_INT_TYPE_SWITCH_EXT(out.type_flag_, OType, { + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(lhs.Size() + rhs.Size()), s); + TBlob temp_tblob = TBlob(workspace); + temp_tblob_l = TBlob(reinterpret_cast(temp_tblob.dptr_), + lhs.shape_, + temp_tblob.dev_mask(), + temp_tblob.dev_id()); + temp_tblob_r = TBlob(reinterpret_cast(temp_tblob.dptr_) + lhs.Size() + 1, + rhs.shape_, + temp_tblob.dev_mask(), + temp_tblob.dev_id()); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob_l}); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob_r}); + BinaryBroadcastIntCompute(attrs, ctx, {temp_tblob_l, temp_tblob_r}, req, outputs); + } + return; +} + +inline bool NumpyBinaryMixedFloatingType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const int ltype = in_attrs->at(0); + const int rtype = in_attrs->at(1); + + if (ltype != -1 && rtype != -1 && (ltype != rtype)) { + // Only when both input types are known and not the same, we enter the mixed-precision mode + TYPE_ASSIGN_CHECK(*out_attrs, 0, common::type_promotion(ltype, rtype)); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); 
+ } + // check if it is float16, float32 or float64. If not, raise error. + CHECK(common::is_float(in_attrs->at(0))) << "Do not support `int` as input.\n"; + return out_attrs->at(0) != -1; +} + +template +void NumpyBinaryMixedFloatingCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + + const TBlob& lhs = inputs[0]; + const TBlob& rhs = inputs[1]; + const TBlob& out = outputs[0]; + + if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) + return; + + if (lhs.type_flag_ == rhs.type_flag_) { + BinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); + return; + } + Stream* s = ctx.get_stream(); + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } + return; +} + template void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -557,7 +875,7 @@ inline bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, const int rtype = in_attrs->at(1); if (ltype != -1 && rtype != -1 && (ltype != rtype)) { // Only when both input types are known and not the same, we enter the mixed-precision mode - TYPE_ASSIGN_CHECK(*out_attrs, 0, common::np_binary_out_infer_type(ltype, rtype)); + TYPE_ASSIGN_CHECK(*out_attrs, 0, common::type_promotion(ltype, rtype)); } else { return ElemwiseType<2, 1>(attrs, in_attrs, out_attrs); } @@ -586,6 +904,88 @@ inline bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") +inline bool NumpyBinaryMixedIntPrecisionTypeWithBool(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const int ltype = in_attrs->at(0); + const int rtype = in_attrs->at(1); + CHECK(common::is_int(ltype) || ltype == mshadow::kBool) + << "1st input only supports integer types or bool types."; + CHECK(common::is_int(rtype) || rtype == mshadow::kBool) + << "2nd input only supports integer types or bool types."; + if (ltype != -1 && rtype != -1 && (ltype != rtype)) { + // Only when both input types are known and not the same, we enter the mixed-precision mode + TYPE_ASSIGN_CHECK(*out_attrs, 0, common::type_promotion(ltype, rtype)); + } else { + return ElemwiseType<2, 1>(attrs, in_attrs, out_attrs); + } + return true; +} + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_INT_PRECISION_WITH_BOOL(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(2) \ + .set_num_outputs(1) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{"lhs", "rhs"}; \ + }) \ + .set_attr("FInferShape", BinaryBroadcastShape) \ + 
.set_attr("FInferType", NumpyBinaryMixedIntPrecisionTypeWithBool) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs) { \ + return std::vector >{{0, 0}, {1, 0}}; \ + }) \ + .set_attr( \ + "FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") + +inline bool NumpyBinaryMixedIntPrecisionType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const int ltype = in_attrs->at(0); + const int rtype = in_attrs->at(1); + CHECK(common::is_int(ltype)) << "1st input only supports integer types."; + CHECK(common::is_int(rtype)) << "2nd input only supports integer types."; + if (ltype != -1 && rtype != -1 && (ltype != rtype)) { + // Only when both input types are known and not the same, we enter the mixed-precision mode + TYPE_ASSIGN_CHECK(*out_attrs, 0, common::type_promotion(ltype, rtype)); + } else { + return ElemwiseType<2, 1>(attrs, in_attrs, out_attrs); + } + return true; +} + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_INT_PRECISION(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(2) \ + .set_num_outputs(1) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{"lhs", "rhs"}; \ + }) \ + .set_attr("FInferShape", BinaryBroadcastShape) \ + .set_attr("FInferType", NumpyBinaryMixedIntPrecisionType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs) { \ + return std::vector >{{0, 0}, {1, 0}}; \ + }) \ + .set_attr( \ + "FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_ diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc index 98a4688002ce..949aad67ab3e 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc @@ -130,23 +130,10 @@ NNVM_REGISTER_OP(_npi_lcm_scalar) .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::ComputeInt); -NNVM_REGISTER_OP(_npi_bitwise_and) - .set_num_inputs(2) - .set_num_outputs(1) - .set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"lhs", "rhs"}; - }) - .set_attr("FInferShape", BinaryBroadcastShape) - .set_attr("FInferType", ElemwiseIntType<2, 1>) - .set_attr("FInplaceOption", - [](const NodeAttrs& attrs) { - return std::vector >{{0, 0}, {1, 0}}; - }) - .set_attr("FGradient", MakeZeroGradNodes) - .set_attr("FCompute", BinaryBroadcastIntCompute) - .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") - .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); +MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_INT_PRECISION_WITH_BOOL(_npi_bitwise_and) + .set_attr("FCompute", + NumpyBinaryBroadcastIntComputeWithBool) + .set_attr("FGradient", MakeZeroGradNodes); NNVM_REGISTER_OP(_npi_bitwise_and_scalar) .set_num_inputs(1) @@ -163,41 +150,15 @@ NNVM_REGISTER_OP(_npi_bitwise_and_scalar) .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", 
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc
index 98a4688002ce..949aad67ab3e 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc
@@ -130,23 +130,10 @@ NNVM_REGISTER_OP(_npi_lcm_scalar)
     .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
     .set_attr<FCompute>("FCompute", BinaryScalarOp::ComputeInt<cpu, mshadow_op::lcm>);
 
-NNVM_REGISTER_OP(_npi_bitwise_and)
-    .set_num_inputs(2)
-    .set_num_outputs(1)
-    .set_attr<nnvm::FListInputNames>("FListInputNames",
-                                     [](const NodeAttrs& attrs) {
-                                       return std::vector<std::string>{"lhs", "rhs"};
-                                     })
-    .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
-    .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
-    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
-                                    [](const NodeAttrs& attrs) {
-                                      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
-                                    })
-    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
-    .set_attr<FCompute>("FCompute", BinaryBroadcastIntCompute<cpu, mshadow_op::bitwise_and>)
-    .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
-    .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
+MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_INT_PRECISION_WITH_BOOL(_npi_bitwise_and)
+    .set_attr<FCompute>("FCompute",
+                        NumpyBinaryBroadcastIntComputeWithBool<cpu, mshadow_op::bitwise_and>)
+    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
 NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
     .set_num_inputs(1)
@@ -163,41 +150,15 @@ NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
     .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
     .set_attr<FCompute>("FCompute", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_and>);
 
-NNVM_REGISTER_OP(_npi_bitwise_xor)
-    .set_num_inputs(2)
-    .set_num_outputs(1)
-    .set_attr<nnvm::FListInputNames>("FListInputNames",
-                                     [](const NodeAttrs& attrs) {
-                                       return std::vector<std::string>{"lhs", "rhs"};
-                                     })
-    .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
-    .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
-    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
-                                    [](const NodeAttrs& attrs) {
-                                      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
-                                    })
-    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
-    .set_attr<FCompute>("FCompute", BinaryBroadcastIntCompute<cpu, mshadow_op::bitwise_xor>)
-    .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
-    .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
+MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_INT_PRECISION_WITH_BOOL(_npi_bitwise_xor)
+    .set_attr<FCompute>("FCompute",
+                        NumpyBinaryBroadcastIntComputeWithBool<cpu, mshadow_op::bitwise_xor>)
+    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
-NNVM_REGISTER_OP(_npi_bitwise_or)
-    .set_num_inputs(2)
-    .set_num_outputs(1)
-    .set_attr<nnvm::FListInputNames>("FListInputNames",
-                                     [](const NodeAttrs& attrs) {
-                                       return std::vector<std::string>{"lhs", "rhs"};
-                                     })
-    .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
-    .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
-    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
-                                    [](const NodeAttrs& attrs) {
-                                      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
-                                    })
-    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
-    .set_attr<FCompute>("FCompute", BinaryBroadcastIntCompute<cpu, mshadow_op::bitwise_or>)
-    .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
-    .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
+MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_INT_PRECISION_WITH_BOOL(_npi_bitwise_or)
+    .set_attr<FCompute>("FCompute",
+                        NumpyBinaryBroadcastIntComputeWithBool<cpu, mshadow_op::bitwise_or>)
+    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
 NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
     .set_num_inputs(1)
@@ -240,21 +201,6 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar)
 MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_copysign_scalar)
     .set_attr<FCompute>("FCompute", BinaryScalarOp::Backward<cpu, mshadow_op::copysign_grad>);
 
-inline bool Arctan2OpType(const nnvm::NodeAttrs& attrs,
-                          std::vector<int>* in_attrs,
-                          std::vector<int>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 2U);
-  CHECK_EQ(out_attrs->size(), 1U);
-
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
-  TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
-  // check if it is float16, float32 or float64. If not, raise error.
-  CHECK(common::is_float(in_attrs->at(0))) << "Do not support `int` as input.\n";
-  return out_attrs->at(0) != -1;
-}
-
 NNVM_REGISTER_OP(_npi_arctan2)
     .set_num_inputs(2)
     .set_num_outputs(1)
@@ -263,13 +209,17 @@ NNVM_REGISTER_OP(_npi_arctan2)
       return std::vector<std::string>{"x1", "x2"};
     })
     .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
-    .set_attr<nnvm::FInferType>("FInferType", Arctan2OpType)
-    .set_attr<FCompute>("FCompute", BinaryBroadcastCompute<cpu, mshadow_op::arctan2>)
+    .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedFloatingType)
+    .set_attr<FCompute>("FCompute", NumpyBinaryMixedFloatingCompute<cpu, mshadow_op::arctan2>)
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2"})
    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                     [](const NodeAttrs& attrs) {
                                       return std::vector<std::pair<int, int> >{{0, 0}};
                                     })
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& attrs) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
     .add_argument("x1", "NDArray-or-Symbol", "The input array")
     .add_argument("x2", "NDArray-or-Symbol", "The input array");
 
@@ -283,7 +233,7 @@ NNVM_REGISTER_OP(_backward_npi_arctan2)
     })
     .set_attr<FCompute>(
         "FCompute",
-        BinaryBroadcastBackwardUseIn<cpu, mshadow_op::arctan2_grad, mshadow_op::arctan2_rgrad>);
+        NumpyBinaryBackwardUseIn<cpu, mshadow_op::arctan2_grad, mshadow_op::arctan2_rgrad>);
 
 MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_arctan2_scalar)
     .set_attr<FCompute>("FCompute", BinaryScalarOp::Compute<cpu, mshadow_op::arctan2>)
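With Arctan2OpType removed, _npi_arctan2 now accepts mixed floating inputs and requests temp space for the cast. A sketch of the expected front-end behavior:

    from mxnet import np

    x1 = np.array([1.0, -1.0], dtype='float16')
    x2 = np.array([1.0, 1.0], dtype='float64')
    y = np.arctan2(x1, x2)      # previously a type-inference error for mixed floats
    assert y.dtype == np.float64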
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended_thi.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended_thi.cc
index 90ecd6e2387a..7fc8d9a9635f 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_extended_thi.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_extended_thi.cc
@@ -44,24 +44,10 @@ namespace op {
       .add_argument("data", "NDArray-or-Symbol", "source input")  \
       .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 
-NNVM_REGISTER_OP(_npi_bitwise_left_shift)
-    .set_num_inputs(2)
-    .set_num_outputs(1)
-    .set_attr<nnvm::FListInputNames>("FListInputNames",
-                                     [](const NodeAttrs& attrs) {
-                                       return std::vector<std::string>{"lhs", "rhs"};
-                                     })
-    .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
-    .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
-    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
-                                    [](const NodeAttrs& attrs) {
-                                      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
-                                    })
+MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_INT_PRECISION(_npi_bitwise_left_shift)
     .set_attr<FCompute>("FCompute",
-                        BinaryBroadcastCompute<cpu, mshadow_op::bitwise_left_shift>)
-    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_bitwise_left_shift"})
-    .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
-    .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
+                        NumpyBinaryBroadcastIntCompute<cpu, mshadow_op::bitwise_left_shift>)
+    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_bitwise_left_shift"});
 
 NNVM_REGISTER_OP(_npi_bitwise_left_shift_scalar)
     .set_num_inputs(1)
@@ -126,24 +112,10 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rbitwise_left_shift_scalar)
     .set_attr<FCompute>("FCompute",
                         BinaryScalarOp::Backward<cpu, mshadow_op::rbitwise_left_shift_grad>);
 
-NNVM_REGISTER_OP(_npi_bitwise_right_shift)
-    .set_num_inputs(2)
-    .set_num_outputs(1)
-    .set_attr<nnvm::FListInputNames>("FListInputNames",
-                                     [](const NodeAttrs& attrs) {
-                                       return std::vector<std::string>{"lhs", "rhs"};
-                                     })
-    .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
-    .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
-    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
-                                    [](const NodeAttrs& attrs) {
-                                      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
-                                    })
+MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_INT_PRECISION(_npi_bitwise_right_shift)
     .set_attr<FCompute>("FCompute",
-                        BinaryBroadcastCompute<cpu, mshadow_op::bitwise_right_shift>)
-    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_bitwise_right_shift"})
-    .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
-    .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
+                        NumpyBinaryBroadcastIntCompute<cpu, mshadow_op::bitwise_right_shift>)
+    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_bitwise_right_shift"});
 
 NNVM_REGISTER_OP(_npi_bitwise_right_shift_scalar)
     .set_num_inputs(1)
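The shift operators take the bool-free integer path. A sketch of the intended behavior (assuming the front end exposes the Array API name bitwise_left_shift, as these registrations suggest):

    from mxnet import np

    a = np.array([1, 2, 4], dtype='uint8')
    n = np.array([1, 2, 3], dtype='int32')
    out = np.bitwise_left_shift(a, n)
    assert out.dtype == np.int32    # uint8 promotes with int32 to int32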
ElemwiseGradUseIn{"_backward_npi_bitwise_right_shift"}); NNVM_REGISTER_OP(_npi_bitwise_right_shift_scalar) .set_num_inputs(1) diff --git a/src/operator/numpy/np_elemwise_broadcast_op_lae.cc b/src/operator/numpy/np_elemwise_broadcast_op_lae.cc index 05d83d819dc9..651fbf6fe2eb 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_lae.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_lae.cc @@ -27,9 +27,28 @@ namespace mxnet { namespace op { -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_logaddexp) - .set_attr("FCompute", BinaryBroadcastCompute) - .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_logaddexp"}); +NNVM_REGISTER_OP(_npi_logaddexp) + .set_num_inputs(2) + .set_num_outputs(1) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"x1", "x2"}; + }) + .set_attr("FInferShape", BinaryBroadcastShape) + .set_attr("FInferType", NumpyBinaryMixedFloatingType) + .set_attr("FCompute", + NumpyBinaryMixedFloatingCompute) + .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_logaddexp"}) + .set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) + .set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) + .add_argument("x1", "NDArray-or-Symbol", "The input array") + .add_argument("x2", "NDArray-or-Symbol", "The input array"); MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_logaddexp_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) @@ -49,7 +68,7 @@ NNVM_REGISTER_OP(_backward_npi_logaddexp) }) .set_attr( "FCompute", - BinaryBroadcastBackwardUseIn); + NumpyBinaryBackwardUseIn); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_logaddexp_scalar) .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) diff --git a/src/operator/numpy/np_true_divide-inl.h b/src/operator/numpy/np_true_divide-inl.h index 6424e22ad209..047489f648cc 100644 --- a/src/operator/numpy/np_true_divide-inl.h +++ b/src/operator/numpy/np_true_divide-inl.h @@ -117,7 +117,34 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs& attrs, // Case when types of the 2 input tensors are different if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { // both lhs and rhs are float types, output type is the more precise one - LOG(FATAL) << "not implemented yet..."; + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, { + Kernel, xpu>::Launch( + s, out.Size(), out.dptr(), lhs.dptr(), temp_tblob.dptr()); + }); + }); + } else { + MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, { + Kernel, xpu>::Launch( + s, out.Size(), out.dptr(), temp_tblob.dptr(), rhs.dptr()); + }); + }); + } } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { // one is float type, the other is integer type, the output type should be the same as float CHECK_EQ(out.type_flag_, common::is_float(lhs.type_flag_) ? 
diff --git a/src/operator/numpy/np_true_divide-inl.h b/src/operator/numpy/np_true_divide-inl.h
index 6424e22ad209..047489f648cc 100644
--- a/src/operator/numpy/np_true_divide-inl.h
+++ b/src/operator/numpy/np_true_divide-inl.h
@@ -117,7 +117,34 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs& attrs,
     // Case when types of the 2 input tensors are different
     if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
       // both lhs and rhs are float types, output type is the more precise one
-      LOG(FATAL) << "not implemented yet...";
+      TBlob temp_tblob;
+      if (lhs.type_flag_ == out.type_flag_) {
+        MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, {
+          Tensor<xpu, 1, LType> temp_tensor =
+              ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
+          temp_tblob = TBlob(temp_tensor);
+        });
+        CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
+        MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+          MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
+            Kernel<op_with_req<mshadow_op::true_divide, Req>, xpu>::Launch(
+                s, out.Size(), out.dptr<DType>(), lhs.dptr<DType>(), temp_tblob.dptr<DType>());
+          });
+        });
+      } else {
+        MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, {
+          Tensor<xpu, 1, RType> temp_tensor =
+              ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
+          temp_tblob = TBlob(temp_tensor);
+        });
+        CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
+        MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+          MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
+            Kernel<op_with_req<mshadow_op::true_divide, Req>, xpu>::Launch(
+                s, out.Size(), out.dptr<DType>(), temp_tblob.dptr<DType>(), rhs.dptr<DType>());
+          });
+        });
+      }
     } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
       // one is float type, the other is integer type, the output type should be the same as float
       CHECK_EQ(out.type_flag_, common::is_float(lhs.type_flag_) ? lhs.type_flag_ : rhs.type_flag_)
@@ -213,7 +240,46 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs,
     } else {
       if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
         // lhs and rhs have different float types, the output is the more precise one
-        LOG(FATAL) << "not implemented yet...";
+        TBlob temp_tblob;
+        if (lhs.type_flag_ == out.type_flag_) {
+          MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, {
+            Tensor<xpu, 1, LType> temp_tensor =
+                ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
+            temp_tblob = TBlob(temp_tensor);
+          });
+          CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
+          MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
+            Kernel<binary_broadcast_kernel<NDim, mshadow_op::true_divide>,
+                   xpu>::template LaunchEx(s,
+                                           new_oshape.Size(),
+                                           req[0],
+                                           lstride,
+                                           rstride,
+                                           oshape,
+                                           lhs.dptr<DType>(),
+                                           temp_tblob.dptr<DType>(),
+                                           out.dptr<DType>());
+          });
+        } else {
+          MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, {
+            Tensor<xpu, 1, RType> temp_tensor =
+                ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
+            temp_tblob = TBlob(temp_tensor);
+          });
+          CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
+          MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
+            Kernel<binary_broadcast_kernel<NDim, mshadow_op::true_divide>,
+                   xpu>::template LaunchEx(s,
+                                           new_oshape.Size(),
+                                           req[0],
+                                           lstride,
+                                           rstride,
+                                           oshape,
+                                           temp_tblob.dptr<DType>(),
+                                           rhs.dptr<DType>(),
+                                           out.dptr<DType>());
+          });
+        }
       } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
         // one of lhs and rhs is float, the output is the same type as the float one
         if (common::is_float(lhs.type_flag_)) {
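The two TrueDivide paths above replace the old LOG(FATAL) stub for mixed-float division with the same cast-into-workspace scheme. Expected front-end behavior (a sketch):

    from mxnet import np

    num = np.ones((2, 2), dtype='float16')
    den = np.full((2, 2), 2.0, dtype='float32')
    q = num / den             # mixed floats: float16 operand is cast up, division runs in float32
    assert q.dtype == np.float32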
diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc
index 13fb72ca970a..99cd4c4718bc 100644
--- a/src/operator/numpy/np_true_divide.cc
+++ b/src/operator/numpy/np_true_divide.cc
@@ -30,7 +30,7 @@ namespace op {
 int TrueDivideOutType(int ltype, int rtype) {
   if (common::is_float(ltype) && common::is_float(rtype)) {
     // If both inputs are float, return the one with the higher precision
-    return common::get_more_precise_type(ltype, rtype);
+    return common::type_promotion(ltype, rtype);
   } else if (common::is_float(ltype) || common::is_float(rtype)) {
     // If only one of the inputs is float, return that float type
     return (common::is_float(ltype)) ? ltype : rtype;
@@ -74,6 +74,10 @@ NNVM_REGISTER_OP(_npi_true_divide)
                                     [](const NodeAttrs& attrs) {
                                       return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
                                     })
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& attrs) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
     .set_attr<FCompute>("FCompute", TrueDivideBroadcastCompute<cpu>)
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_div"})
     .add_argument("lhs", "NDArray-or-Symbol", "Dividend array")
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index d5ba8c2f60c0..b8f2902444fa 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -1359,8 +1359,8 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
   Stream<xpu>* s = ctx.get_stream<xpu>();
   bool isCPU     = std::is_same<xpu, cpu>::value;
-  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
-    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
+  MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(inputs[0].type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(outputs[0].type_flag_, OType, {
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape;
       for (int i = 0; i < MXNET_SPECIAL_MAX_NDIM; ++i) {
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index fbf42c515225..20d874dbd826 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -236,6 +236,43 @@ void BinaryBroadcastIntCompute(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template <typename xpu, typename OP>
+void BinaryBroadcastIntComputeWithBool(const nnvm::NodeAttrs& attrs,
+                                       const OpContext& ctx,
+                                       const std::vector<TBlob>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<TBlob>& outputs) {
+  if (outputs[0].shape_.Size() == 0U)
+    return;
+  mxnet::TShape new_lshape, new_rshape, new_oshape;
+  int ndim = BinaryBroadcastShapeCompact(
+      inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, &new_lshape, &new_rshape, &new_oshape);
+  if (!ndim) {
+    ElemwiseBinaryOp::ComputeIntWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs);
+  } else {
+    if (req[0] == kNullOp)
+      return;
+    mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+    MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(ndim, NDim, {
+        mshadow::Shape<NDim> oshape  = new_oshape.get<NDim>();
+        mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+        mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+        mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::template LaunchEx(
+            s,
+            new_oshape.Size(),
+            req[0],
+            lstride,
+            rstride,
+            oshape,
+            inputs[0].dptr<DType>(),
+            inputs[1].dptr<DType>(),
+            outputs[0].dptr<DType>());
+      });
+    });
+  }
+}
+
 template <typename xpu, typename OP>
 void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
@@ -256,7 +293,7 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
   if (outputs[0].type_flag_ == mshadow::kBool) {
     LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type";
   }
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH_EXT(outputs[0].type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(ndim, NDim, {
       broadcast::BinaryBroadcastComputeImpl<NDim, DType, OP>(s,
                                                              req[0],
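BinaryBroadcastIntComputeWithBool is what lets the bitwise operators broadcast over bool tensors. A sketch of the resulting behavior:

    from mxnet import np

    m = np.array([[True], [False]])       # shape (2, 1)
    v = np.array([True, False, True])     # shape (3,)
    out = np.bitwise_or(m, v)             # broadcasts to (2, 3), stays bool
    assert out.shape == (2, 3) and out.dtype == np.bool_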
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index b4a7498f0eba..4f36b8acd404 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -461,6 +461,31 @@ class ElemwiseBinaryOp : public OpBase {
     });
   }
 
+  template <typename xpu, typename OP>
+  static void ComputeIntWithBool(const nnvm::NodeAttrs& attrs,
+                                 const OpContext& ctx,
+                                 const std::vector<TBlob>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<TBlob>& outputs) {
+    using namespace mxnet_op;
+    if (req[0] == kNullOp)
+      return;
+    Stream<xpu>* s = ctx.get_stream<xpu>();
+    CHECK_EQ(inputs.size(), 2U);
+    CHECK_EQ(outputs.size(), 1U);
+    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+      MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(outputs[0].type_flag_, DType, {
+        const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) +
+                             DataType<DType>::kLanes - 1) /
+                            DataType<DType>::kLanes;
+        if (size != 0) {
+          Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+              s, size, outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
+        }
+      });
+    });
+  }
+
   template <typename xpu, typename OP>
   static void Compute(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
@@ -477,7 +502,7 @@ class ElemwiseBinaryOp : public OpBase {
       LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type";
     }
     MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      MSHADOW_TYPE_SWITCH_EXT(outputs[0].type_flag_, DType, {
        const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) +
                             DataType<DType>::kLanes - 1) /
                            DataType<DType>::kLanes;
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 38949f1769ed..f516a7858c62 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -243,7 +243,7 @@ class UnaryOp : public OpBase {
                          const std::vector<TBlob>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<TBlob>& outputs) {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH_EXT(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         if (inputs[0].Size() != 0) {
           mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch(
@@ -275,7 +275,7 @@ class UnaryOp : public OpBase {
       UnaryOp::Compute<xpu, OP>(attrs, ctx, inputs, req, outputs);
     } else {
       MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-        MXNET_INT_TYPE_SWITCH(inputs[0].type_flag_, IType, {
+        MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(inputs[0].type_flag_, IType, {
           MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
             if (inputs[0].Size() != 0) {
               mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
@@ -294,7 +294,7 @@ class UnaryOp : public OpBase {
                              const std::vector<OpReqType>& req,
                              const std::vector<TBlob>& outputs) {
     mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
-    MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         if (inputs[0].Size() != 0) {
           mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
@@ -311,7 +311,7 @@ class UnaryOp : public OpBase {
                              const std::vector<OpReqType>& req,
                              const std::vector<TBlob>& outputs) {
     mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
-    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(inputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         if (inputs[0].Size() != 0) {
          mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
@@ -700,7 +700,7 @@ void AroundOpForward(const nnvm::NodeAttrs& attrs,
           s, out_data.Size(), out_data.dptr<DType>(), in_data.dptr<DType>());
     });
   } else {
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH_EXT(out_data.type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
         Kernel<around<req_type>, xpu>::Launch(
             s, out_data.Size(), out_data.dptr<DType>(), in_data.dptr<DType>(), param.decimals);
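The switch-macro substitutions above (MSHADOW_TYPE_SWITCH to _EXT, MXNET_INT_TYPE_SWITCH to _EXT_WITH_BOOL) instantiate the unary and binary kernels for the extended integer dtypes added for the Array API work. A sketch, assuming uint16 is among the newly covered types:

    from mxnet import np

    x = np.array([1, 2, 3], dtype='uint16')
    y = np.square(x)          # unary kernel now instantiated for extended int dtypes
    assert y.dtype == np.uint16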
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 6060f32a9587..b4dcf0b4f485 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -263,7 +263,7 @@ def _add_workload_percentile():
     q3 = np.array([25, 50, 100])
     q4 = 65
     x4 = np.arange(11 * 2).reshape(11, 1, 2, 1)
-    x5 = np.array([0, np.nan])
+    x5 = np.array([0, _np.nan])
 
     OpArgMngr.add_workload('percentile', x1, q1, None, None, None)
     OpArgMngr.add_workload('percentile', x1, q1, None, None, None, 'linear')
@@ -760,9 +760,9 @@ def _add_workload_tril():
                       [[1, 1], [0, 0]],
                   ], dtype=dt)
     OpArgMngr.add_workload('tril', a)
-    arr = np.array([[1, 1, np.inf],
+    arr = np.array([[1, 1, _np.inf],
                     [1, 1, 1],
-                    [np.inf, 1, 1]])
+                    [_np.inf, 1, 1]])
     OpArgMngr.add_workload('tril', arr)
     OpArgMngr.add_workload('tril', np.zeros((3, 3), dtype=dt))
     import mxnet as mx
@@ -780,9 +780,9 @@ def _add_workload_triu():
                       [[1, 1], [0, 0]],
                   ], dtype=dt)
     OpArgMngr.add_workload('triu', a)
-    arr = np.array([[1, 1, np.inf],
+    arr = np.array([[1, 1, _np.inf],
                     [1, 1, 1],
-                    [np.inf, 1, 1]])
+                    [_np.inf, 1, 1]])
     OpArgMngr.add_workload('triu', arr)
     OpArgMngr.add_workload('triu', np.zeros((3, 3), dtype=dt))
@@ -896,8 +896,8 @@ def _add_workload_einsum():
 def _add_workload_expm1():
     OpArgMngr.add_workload('expm1', np.random.uniform(size=(4, 1)))
     OpArgMngr.add_workload('expm1', np.random.uniform(size=(1, 1)))
-    OpArgMngr.add_workload('expm1', np.array([np.inf]))
-    OpArgMngr.add_workload('expm1', np.array([-np.inf]))
+    OpArgMngr.add_workload('expm1', np.array([_np.inf]))
+    OpArgMngr.add_workload('expm1', np.array([-_np.inf]))
     OpArgMngr.add_workload('expm1', np.array([0.]))
     OpArgMngr.add_workload('expm1', np.array([-0.]))
@@ -908,10 +908,10 @@ def _add_workload_argmax():
     OpArgMngr.add_workload('argmax', np.random.uniform(size=(4, 5, 6, 7, 8)), 2)
     OpArgMngr.add_workload('argmax', np.random.uniform(size=(4, 5, 6, 7, 8)), 3)
     OpArgMngr.add_workload('argmax', np.random.uniform(size=(4, 5, 6, 7, 8)), 4)
-    # OpArgMngr.add_workload('argmax', np.array([0, 1, 2, 3, np.nan]))
-    # OpArgMngr.add_workload('argmax', np.array([0, 1, 2, np.nan, 3]))
-    # OpArgMngr.add_workload('argmax', np.array([np.nan, 0, 1, 2, 3]))
-    # OpArgMngr.add_workload('argmax', np.array([np.nan, 0, np.nan, 2, 3]))
+    # OpArgMngr.add_workload('argmax', np.array([0, 1, 2, 3, _np.nan]))
+    # OpArgMngr.add_workload('argmax', np.array([0, 1, 2, _np.nan, 3]))
+    # OpArgMngr.add_workload('argmax', np.array([_np.nan, 0, 1, 2, 3]))
+    # OpArgMngr.add_workload('argmax', np.array([_np.nan, 0, _np.nan, 2, 3]))
     OpArgMngr.add_workload('argmax', np.array([False, False, False, False, True]))
     OpArgMngr.add_workload('argmax', np.array([False, False, False, True, False]))
     OpArgMngr.add_workload('argmax', np.array([True, False, False, False, False]))
@@ -924,10 +924,10 @@ def _add_workload_argmin():
     OpArgMngr.add_workload('argmin', np.random.uniform(size=(4, 5, 6, 7, 8)), 2)
     OpArgMngr.add_workload('argmin', np.random.uniform(size=(4, 5, 6, 7, 8)), 3)
     OpArgMngr.add_workload('argmin', np.random.uniform(size=(4, 5, 6, 7, 8)), 4)
-    # OpArgMngr.add_workload('argmin', np.array([0, 1, 2, 3, np.nan]))
-    # OpArgMngr.add_workload('argmin', np.array([0, 1, 2, np.nan, 3]))
-    # OpArgMngr.add_workload('argmin', np.array([np.nan, 0, 1, 2, 3]))
-    # OpArgMngr.add_workload('argmin', np.array([np.nan, 0, np.nan, 2, 3]))
+    # OpArgMngr.add_workload('argmin', np.array([0, 1, 2, 3, _np.nan]))
+    # OpArgMngr.add_workload('argmin', np.array([0, 1, 2, _np.nan, 3]))
+    # OpArgMngr.add_workload('argmin', np.array([_np.nan, 0, 1, 2, 3]))
+    # OpArgMngr.add_workload('argmin', np.array([_np.nan, 0, _np.nan, 2, 3]))
     OpArgMngr.add_workload('argmin', np.array([False, False, False, False, True]))
     OpArgMngr.add_workload('argmin', np.array([False, False, False, True, False]))
     OpArgMngr.add_workload('argmin', np.array([True, False, False, False, False]))
@@ -1004,7 +1004,7 @@ def _add_workload_clip():
     # OpArgMngr.add_workload('clip', np.array([0, 1, 2, 3, 4, 5, 6, 7]), 3)
     # OpArgMngr.add_workload('clip', np.array([0, 1, 2, 3, 4, 5, 6, 7]), a_min=3)
     # OpArgMngr.add_workload('clip', np.array([0, 1, 2, 3, 4, 5, 6, 7]), a_max=4)
-    OpArgMngr.add_workload('clip', np.array([-2., np.nan, 0.5, 3., 0.25, np.nan]), -1, 1)
+    OpArgMngr.add_workload('clip', np.array([-2., _np.nan, 0.5, 3., 0.25, _np.nan]), -1, 1)
 
 
 def _add_workload_cumsum():
@@ -1311,13 +1311,13 @@ def _add_workload_delete():
 
 def _add_workload_var(array_pool):
     OpArgMngr.add_workload('var', array_pool['4x1'])
-    OpArgMngr.add_workload('var', np.array([np.float16(1.)]))
+    OpArgMngr.add_workload('var', np.array([_np.float16(1.)]))
     OpArgMngr.add_workload('var', np.array([1]))
     OpArgMngr.add_workload('var', np.array([1.]))
     OpArgMngr.add_workload('var', np.array([[1, 2, 3], [4, 5, 6]]))
     OpArgMngr.add_workload('var', np.array([[1, 2, 3], [4, 5, 6]]), 0)
     OpArgMngr.add_workload('var', np.array([[1, 2, 3], [4, 5, 6]]), 1)
-    OpArgMngr.add_workload('var', np.array([np.nan]))
+    OpArgMngr.add_workload('var', np.array([_np.nan]))
     OpArgMngr.add_workload('var', np.array([1, -1, 1, -1]))
     OpArgMngr.add_workload('var', np.array([1,2,3,4], dtype='f8'))
@@ -1333,7 +1333,7 @@ def _add_workload_full_like(array_pool):
     OpArgMngr.add_workload('full_like', array_pool['4x1'], 1)
     OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(1,3,4), dtype='float64'), 1)
     OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(9,3,1)), 2, dtype=np.int64)
-    OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(9,3)), np.nan)
+    OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(9,3)), _np.nan)
     OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(2,0)), 0, dtype=np.float32)
@@ -1357,13 +1357,13 @@ def _add_workload_meshgrid():
 def _add_workload_abs():
     OpArgMngr.add_workload('abs', np.random.uniform(size=(11,)).astype(np.float32))
     OpArgMngr.add_workload('abs', np.random.uniform(size=(5,)).astype(np.float64))
-    OpArgMngr.add_workload('abs', np.array([np.inf, -np.inf, np.nan]))
+    OpArgMngr.add_workload('abs', np.array([_np.inf, -_np.inf, _np.nan]))
 
 
 def _add_workload_fabs():
     OpArgMngr.add_workload('fabs', np.random.uniform(size=(11,)).astype(np.float32))
     OpArgMngr.add_workload('fabs', np.random.uniform(size=(5,)).astype(np.float64))
-    OpArgMngr.add_workload('fabs', np.array([np.inf, -np.inf, np.nan]))
+    OpArgMngr.add_workload('fabs', np.array([_np.inf, -_np.inf, _np.nan]))
 
 
 def _add_workload_add(array_pool):
@@ -1381,10 +1381,10 @@ def _add_workload_arctan2():
     OpArgMngr.add_workload('arctan2', np.array([np.PZERO, np.NZERO]), np.array([1, 1]))
     OpArgMngr.add_workload('arctan2', np.array([-1, -1]), np.array([np.PZERO, np.NZERO]))
     OpArgMngr.add_workload('arctan2', np.array([1, 1]), np.array([np.PZERO, np.NZERO]))
-    OpArgMngr.add_workload('arctan2', np.array([1, -1, 1, -1]), np.array([-np.inf, -np.inf, np.inf, np.inf]))
-    OpArgMngr.add_workload('arctan2', np.array([np.inf, -np.inf]), np.array([1, 1]))
-    OpArgMngr.add_workload('arctan2', np.array([np.inf, -np.inf]), np.array([-np.inf, -np.inf]))
-    OpArgMngr.add_workload('arctan2', np.array([np.inf, -np.inf]), np.array([np.inf, np.inf]))
+    OpArgMngr.add_workload('arctan2', np.array([1, -1, 1, -1]), np.array([-_np.inf, -_np.inf, _np.inf, _np.inf]))
+    OpArgMngr.add_workload('arctan2', np.array([_np.inf, -_np.inf]), np.array([1, 1]))
+    OpArgMngr.add_workload('arctan2', np.array([_np.inf, -_np.inf]), np.array([-_np.inf, -_np.inf]))
+    OpArgMngr.add_workload('arctan2', np.array([_np.inf, -_np.inf]), np.array([_np.inf, _np.inf]))
 
 
 def _add_workload_copysign():
@@ -1442,7 +1442,7 @@ def _add_workload_interp():
     fp0 = np.linspace(0, 1, 5)
     x0 = np.linspace(0, 1, 50)
     xp1 = np.array([1, 2, 3, 4])
-    fp1 = np.array([1, 2, np.inf, 4])
+    fp1 = np.array([1, 2, _np.inf, 4])
     x1 = np.array([1, 2, 2.5, 3, 4])
     xp2 = np.arange(0, 10, 0.0001)
     fp2 = np.sin(xp2)
@@ -1472,14 +1472,14 @@ def _add_workload_interp():
 def _add_workload_hypot():
     OpArgMngr.add_workload('hypot', np.array(1), np.array(1))
     OpArgMngr.add_workload('hypot', np.array(0), np.array(0))
-    OpArgMngr.add_workload('hypot', np.array(np.nan), np.array(np.nan))
-    OpArgMngr.add_workload('hypot', np.array(np.nan), np.array(1))
-    OpArgMngr.add_workload('hypot', np.array(np.nan), np.array(np.inf))
-    OpArgMngr.add_workload('hypot', np.array(np.inf), np.array(np.nan))
-    OpArgMngr.add_workload('hypot', np.array(np.inf), np.array(0))
-    OpArgMngr.add_workload('hypot', np.array(0), np.array(np.inf))
-    OpArgMngr.add_workload('hypot', np.array(np.inf), np.array(np.inf))
-    OpArgMngr.add_workload('hypot', np.array(np.inf), np.array(23.0))
+    OpArgMngr.add_workload('hypot', np.array(_np.nan), np.array(_np.nan))
+    OpArgMngr.add_workload('hypot', np.array(_np.nan), np.array(1))
+    OpArgMngr.add_workload('hypot', np.array(_np.nan), np.array(_np.inf))
+    OpArgMngr.add_workload('hypot', np.array(_np.inf), np.array(_np.nan))
+    OpArgMngr.add_workload('hypot', np.array(_np.inf), np.array(0))
+    OpArgMngr.add_workload('hypot', np.array(0), np.array(_np.inf))
+    OpArgMngr.add_workload('hypot', np.array(_np.inf), np.array(_np.inf))
+    OpArgMngr.add_workload('hypot', np.array(_np.inf), np.array(23.0))
 
 
 def _add_workload_lcm():
@@ -1673,8 +1673,8 @@ def _signs(dt):
     for ct in [np.float16, np.float32, np.float64]:
         fone = np.array(1.0, dtype=ct)
         fzer = np.array(0.0, dtype=ct)
-        finf = np.array(np.inf, dtype=ct)
-        fnan = np.array(np.nan, dtype=ct)
+        finf = np.array(_np.inf, dtype=ct)
+        fnan = np.array(_np.nan, dtype=ct)
         # OpArgMngr.add_workload('remainder', fone, fzer)  # failed
         OpArgMngr.add_workload('remainder', fone, fnan)
         OpArgMngr.add_workload('remainder', finf, fone)
@@ -1734,13 +1734,13 @@ def _add_workload_log(array_pool):
 def _add_workload_log2(array_pool):
     OpArgMngr.add_workload('log2', array_pool['4x1'])
     OpArgMngr.add_workload('log2', np.array(2.**65))
-    OpArgMngr.add_workload('log2', np.array(np.inf))
+    OpArgMngr.add_workload('log2', np.array(_np.inf))
     OpArgMngr.add_workload('log2', np.array(1.))
 
 
 def _add_workload_log1p():
     OpArgMngr.add_workload('log1p', np.array(-1.))
-    OpArgMngr.add_workload('log1p', np.array(np.inf))
+    OpArgMngr.add_workload('log1p', np.array(_np.inf))
     OpArgMngr.add_workload('log1p', np.array(1e-6))
@@ -1749,7 +1749,7 @@ def _add_workload_log10(array_pool):
 
 
 def _add_workload_sqrt():
-    OpArgMngr.add_workload('sqrt', np.array([1, np.PZERO, np.NZERO, np.inf, np.nan]))
+    OpArgMngr.add_workload('sqrt', np.array([1, np.PZERO, np.NZERO, _np.inf, _np.nan]))
 
 
 def _add_workload_square():
@@ -1758,8 +1758,8 @@ def _add_workload_square():
 
 def _add_workload_cbrt():
     OpArgMngr.add_workload('cbrt', np.array(-2.5**3, dtype=np.float32))
-    OpArgMngr.add_workload('cbrt', np.array([1., 2., -3., np.inf, -np.inf])**3)
-    OpArgMngr.add_workload('cbrt', np.array([np.inf, -np.inf, np.nan]))
+    OpArgMngr.add_workload('cbrt', np.array([1., 2., -3., _np.inf, -_np.inf])**3)
+    OpArgMngr.add_workload('cbrt', np.array([_np.inf, -_np.inf, _np.nan]))
 
 
 def _add_workload_reciprocal():
@@ -1983,8 +1983,8 @@ def _add_workload_equal(array_pool):
     # TODO(junwu): fp16 does not work yet with TVM generated ops
     # OpArgMngr.add_workload('equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
     OpArgMngr.add_workload('equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
-    # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
-    # OpArgMngr.add_workload('equal', np.array([np.nan]), np.array([np.nan]))
+    # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with _np.nan
+    # OpArgMngr.add_workload('equal', np.array([_np.nan]), np.array([_np.nan]))
     OpArgMngr.add_workload('equal', array_pool['4x1'], array_pool['1x2'])
@@ -1992,8 +1992,8 @@ def _add_workload_not_equal(array_pool):
     # TODO(junwu): fp16 does not work yet with TVM generated ops
     # OpArgMngr.add_workload('not_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
     OpArgMngr.add_workload('not_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
-    # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
-    # OpArgMngr.add_workload('not_equal', np.array([np.nan]), np.array([np.nan]))
+    # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with _np.nan
+    # OpArgMngr.add_workload('not_equal', np.array([_np.nan]), np.array([_np.nan]))
     OpArgMngr.add_workload('not_equal', array_pool['4x1'], array_pool['1x2'])
@@ -2004,8 +2004,8 @@ def _add_workload_greater(array_pool):
     OpArgMngr.add_workload('greater', array_pool['4x1'], array_pool['1x2'])
     OpArgMngr.add_workload('greater', array_pool['4x1'], 2)
     OpArgMngr.add_workload('greater', 2, array_pool['4x1'])
-    # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
-    # OpArgMngr.add_workload('greater', np.array([np.nan]), np.array([np.nan]))
+    # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with _np.nan
+    # OpArgMngr.add_workload('greater', np.array([_np.nan]), np.array([_np.nan]))
 
 
 def _add_workload_greater_equal(array_pool):
@@ -2015,8 +2015,8 @@ def _add_workload_greater_equal(array_pool):
     OpArgMngr.add_workload('greater_equal', array_pool['4x1'], array_pool['1x2'])
     OpArgMngr.add_workload('greater_equal', array_pool['4x1'], 2)
     OpArgMngr.add_workload('greater_equal', 2, array_pool['4x1'])
-    # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
-    # OpArgMngr.add_workload('greater_equal', np.array([np.nan]), np.array([np.nan]))
+    # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with _np.nan
+    # OpArgMngr.add_workload('greater_equal', np.array([_np.nan]), np.array([_np.nan]))
 
 
 def _add_workload_less(array_pool):
@@ -2026,8 +2026,8 @@ def _add_workload_less(array_pool):
     OpArgMngr.add_workload('less', array_pool['4x1'], array_pool['1x2'])
     OpArgMngr.add_workload('less', array_pool['4x1'], 2)
     OpArgMngr.add_workload('less', 2, array_pool['4x1'])
-    # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
-    # OpArgMngr.add_workload('less', np.array([np.nan]), np.array([np.nan]))
+    # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with _np.nan
+    # OpArgMngr.add_workload('less', np.array([_np.nan]), np.array([_np.nan]))
 
 
 def _add_workload_less_equal(array_pool):
@@ -2037,8 +2037,8 @@ def _add_workload_less_equal(array_pool):
     OpArgMngr.add_workload('less_equal', array_pool['4x1'], array_pool['1x2'])
     OpArgMngr.add_workload('less_equal', array_pool['4x1'], 2)
     OpArgMngr.add_workload('less_equal', 2, array_pool['4x1'])
-    # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
-    # OpArgMngr.add_workload('less_equal', np.array([np.nan]), np.array([np.nan]))
+    # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with _np.nan
+    # OpArgMngr.add_workload('less_equal', np.array([_np.nan]), np.array([_np.nan]))
 
 
 def _add_workload_logical_and(array_pool):
@@ -2240,8 +2240,8 @@ def _add_workload_polyval():
 def _add_workload_linalg_cond():
     A = np.array([[1., 0, 1], [0, -2., 0], [0, 0, 3.]])
-    OpArgMngr.add_workload('linalg.cond', A, np.inf)
-    OpArgMngr.add_workload('linalg.cond', A, -np.inf)
+    OpArgMngr.add_workload('linalg.cond', A, _np.inf)
+    OpArgMngr.add_workload('linalg.cond', A, -_np.inf)
     OpArgMngr.add_workload('linalg.cond', A, 1)
     OpArgMngr.add_workload('linalg.cond', A, -1)
     OpArgMngr.add_workload('linalg.cond', A, 'fro')
@@ -2286,22 +2286,22 @@ def _add_workload_linalg_multi_dot():
 
 def _add_workload_heaviside():
-    x = np.array([[-30.0, -0.1, 0.0, 0.2], [7.5, np.nan, np.inf, -np.inf]], dtype=np.float64)
+    x = np.array([[-30.0, -0.1, 0.0, 0.2], [7.5, _np.nan, _np.inf, -_np.inf]], dtype=np.float64)
     OpArgMngr.add_workload('heaviside', x, 0.5)
     OpArgMngr.add_workload('heaviside', x, 1.0)
 
     x = x.astype(np.float32)
-    OpArgMngr.add_workload('heaviside', x, np.float32(0.5))
-    OpArgMngr.add_workload('heaviside', x, np.float32(1.0))
+    OpArgMngr.add_workload('heaviside', x, _np.float32(0.5))
+    OpArgMngr.add_workload('heaviside', x, _np.float32(1.0))
 
 
 def _add_workload_spacing():
-    OpArgMngr.add_workload('spacing', np.float64(1))
-    OpArgMngr.add_workload('spacing', np.float32(1))
-    OpArgMngr.add_workload('spacing', np.inf)
-    OpArgMngr.add_workload('spacing', -np.inf)
-    OpArgMngr.add_workload('spacing', np.float64(1e30))
-    OpArgMngr.add_workload('spacing', np.float32(1e30))
+    OpArgMngr.add_workload('spacing', _np.float64(1))
+    OpArgMngr.add_workload('spacing', _np.float32(1))
+    OpArgMngr.add_workload('spacing', _np.inf)
+    OpArgMngr.add_workload('spacing', -_np.inf)
+    OpArgMngr.add_workload('spacing', _np.float64(1e30))
+    OpArgMngr.add_workload('spacing', _np.float32(1e30))
 
 
 def _add_workload_allclose():
@@ -2548,14 +2548,14 @@ def _add_workload_interp():
     x0 = np.linspace(0, 1, 50)
     x1 = 0
     x2 = .3
-    x3 = np.float32(.3)
+    x3 = _np.float32(.3)
     OpArgMngr.add_workload('interp', x0, x, y)
     OpArgMngr.add_workload('interp', x1, x, y)
     OpArgMngr.add_workload('interp', x2, x, y)
     OpArgMngr.add_workload('interp', x3, x, y)
     x = np.array([1, 2, 2.5, 3, 4])
     xp = np.array([1, 2, 3, 4])
-    fp = np.array([1, 2, np.inf, 4])
+    fp = np.array([1, 2, _np.inf, 4])
     OpArgMngr.add_workload('interp', x, xp, fp)
@@ -2574,7 +2574,7 @@ def _add_workload_intersect1d():
 def _add_workload_isclose():
     a = np.array([1e10,1e-7])
     b = np.array([1.00001e10,1e-8])
-    c = np.array([1.0, np.nan])
+    c = np.array([1.0, _np.nan])
     d = np.array([0.0, 0.0])
     e = np.array([1e-100, 1e-7])
     OpArgMngr.add_workload('isclose', a, b)
@@ -2633,56 +2633,56 @@ def _add_workload_msort():
 
 def _add_workload_nanargmax():
-    a = np.array([[np.nan, 4], [2, 3]])
+    a = np.array([[_np.nan, 4], [2, 3]])
     OpArgMngr.add_workload('nanargmax', a)
     OpArgMngr.add_workload('nanargmax', a, axis=0)
     OpArgMngr.add_workload('nanargmax', a, axis=1)
 
 
 def _add_workload_nanargmin():
-    a = np.array([[np.nan, 4], [2, 3]])
+    a = np.array([[_np.nan, 4], [2, 3]])
     OpArgMngr.add_workload('nanargmin', a)
     OpArgMngr.add_workload('nanargmin', a, axis=0)
     OpArgMngr.add_workload('nanargmin', a, axis=1)
 
 
 def _add_workload_nancumprod():
-    a = np.array([[1, 2], [3, np.nan]])
+    a = np.array([[1, 2], [3, _np.nan]])
     OpArgMngr.add_workload('nancumprod', a)
     OpArgMngr.add_workload('nancumprod', a, axis=0)
     OpArgMngr.add_workload('nancumprod', a, axis=1)
 
 
 def _add_workload_nancumsum():
-    a = np.array([[1, 2], [3, np.nan]])
+    a = np.array([[1, 2], [3, _np.nan]])
     OpArgMngr.add_workload('nancumsum', a)
     OpArgMngr.add_workload('nancumsum', a, axis=0)
     OpArgMngr.add_workload('nancumsum', a, axis=1)
 
 
 def _add_workload_nanmax():
-    a = np.array([[1, 2], [3, np.nan]])
+    a = np.array([[1, 2], [3, _np.nan]])
     OpArgMngr.add_workload('nanmax', a)
     OpArgMngr.add_workload('nanmax', a, axis=0)
     OpArgMngr.add_workload('nanmax', a, axis=1)
 
 
 def _add_workload_nanmedian():
-    a = np.array([[10.0, np.nan, 4], [3, 2, 1]])
+    a = np.array([[10.0, _np.nan, 4], [3, 2, 1]])
     OpArgMngr.add_workload('nanmedian', a)
     OpArgMngr.add_workload('nanmedian', a, axis=0)
     OpArgMngr.add_workload('nanmedian', a, axis=1)
 
 
 def _add_workload_nanmin():
-    a = np.array([[1, 2], [3, np.nan]])
+    a = np.array([[1, 2], [3, _np.nan]])
     OpArgMngr.add_workload('nanmin', a)
     OpArgMngr.add_workload('nanmin', a, axis=0)
     OpArgMngr.add_workload('nanmin', a, axis=1)
 
 
 def _add_workload_nanpercentile():
-    a = np.array([[10.0, np.nan, 4], [3, 2, 1]])
+    a = np.array([[10.0, _np.nan, 4], [3, 2, 1]])
     OpArgMngr.add_workload('nanpercentile', a, 50)
     OpArgMngr.add_workload('nanpercentile', a, 50, axis=0)
     OpArgMngr.add_workload('nanpercentile', a, 50, axis=1)
@@ -2695,8 +2695,8 @@ def _add_workload_nanpercentile():
 
 def _add_workload_nanprod():
     a = 1
-    b = np.array([1, np.nan])
-    c = np.array([[1, 2], [3, np.nan]])
+    b = np.array([1, _np.nan])
+    c = np.array([[1, 2], [3, _np.nan]])
     OpArgMngr.add_workload('nanprod', a)
     OpArgMngr.add_workload('nanprod', b)
     OpArgMngr.add_workload('nanprod', c)
@@ -2704,7 +2704,7 @@ def _add_workload_nanprod():
 
 def _add_workload_nanquantile():
-    a = np.array([[10.0, np.nan, 4], [3, 2, 1]])
+    a = np.array([[10.0, _np.nan, 4], [3, 2, 1]])
     OpArgMngr.add_workload('nanquantile', a, 0.4)
     OpArgMngr.add_workload('nanquantile', a, 0.4, axis=0)
     OpArgMngr.add_workload('nanquantile', a, 0.4, axis=1)
@@ -2717,7 +2717,7 @@ def _add_workload_nanquantile():
 
 def _add_workload_nanstd():
     OpArgMngr.add_workload('nanstd', np.random.uniform(size=(4, 1)))
-    A = np.array([[1, 2, 3], [4, np.nan, 6]])
+    A = np.array([[1, 2, 3], [4, _np.nan, 6]])
     OpArgMngr.add_workload('nanstd', A)
     OpArgMngr.add_workload('nanstd', A, 0)
     OpArgMngr.add_workload('nanstd', A, 1)
@@ -2729,8 +2729,8 @@ def _add_workload_nanstd():
 
 def _add_workload_nansum():
     a = 1
-    b = np.array([1, np.nan])
-    c = np.array([[1, 2], [3, np.nan]])
+    b = np.array([1, _np.nan])
+    c = np.array([[1, 2], [3, _np.nan]])
     OpArgMngr.add_workload('nansum', a)
     OpArgMngr.add_workload('nansum', b)
     OpArgMngr.add_workload('nansum', c)
@@ -2739,7 +2739,7 @@ def _add_workload_nansum():
 
 def _add_workload_nanvar():
     OpArgMngr.add_workload('nanvar', np.random.uniform(size=(4, 1)))
-    A = np.array([[1, 2, 3], [4, np.nan, 6]])
+    A = np.array([[1, 2, 3], [4, _np.nan, 6]])
     OpArgMngr.add_workload('nanvar', A)
     OpArgMngr.add_workload('nanvar', A, 0)
     OpArgMngr.add_workload('nanvar', A, 1)
@@ -2960,9 +2960,9 @@ def _add_workload_trapz():
 def _add_workload_tril_indices_from():
     for dt in ['float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8']:
         OpArgMngr.add_workload('tril_indices_from', np.ones((2, 2), dtype=dt))
-    arr = np.array([[1, 1, np.inf],
+    arr = np.array([[1, 1, _np.inf],
                     [1, 1, 1],
-                    [np.inf, 1, 1]])
+                    [_np.inf, 1, 1]])
     OpArgMngr.add_workload('tril_indices_from', arr)
     OpArgMngr.add_workload('tril_indices_from', np.zeros((3, 3), dtype=dt))
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 8558c3d561e7..edec96c34ed8 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -622,7 +622,7 @@ def test_nd_no_format():
 @use_np
 @pytest.mark.serial
 def test_np_ndarray_indexing():
-    def np_int(index, int_type=np.int32):
+    def np_int(index, int_type=_np.int32):
         """
        Helper function for testing indexing that converts slices to slices of
        ints or None, and tuples to tuples of ints or None.
@@ -801,70 +801,70 @@ def test_setitem_autograd(np_array, index):
     # Basic indexing
     # Single int as index
     0,
-    np.int32(0),
-    np.int64(0),
+    _np.int32(0),
+    _np.int64(0),
     np.array(0, dtype='int32'),
     np.array(0, dtype='int64'),
     5,
-    np.int32(5),
-    np.int64(5),
+    _np.int32(5),
+    _np.int64(5),
     np.array(5, dtype='int32'),
     np.array(5, dtype='int64'),
     -1,
-    np.int32(-1),
-    np.int64(-1),
+    _np.int32(-1),
+    _np.int64(-1),
     np.array(-1, dtype='int32'),
     np.array(-1, dtype='int64'),
     # Slicing as index
     slice(5),
-    np_int(slice(5), np.int32),
-    np_int(slice(5), np.int64),
+    np_int(slice(5), _np.int32),
+    np_int(slice(5), _np.int64),
     slice(1, 5),
-    np_int(slice(1, 5), np.int32),
-    np_int(slice(1, 5), np.int64),
+    np_int(slice(1, 5), _np.int32),
+    np_int(slice(1, 5), _np.int64),
     slice(1, 5, 2),
     slice(1, 2, 2),
-    np_int(slice(1, 5, 2), np.int32),
-    np_int(slice(1, 5, 2), np.int64),
+    np_int(slice(1, 5, 2), _np.int32),
+    np_int(slice(1, 5, 2), _np.int64),
     slice(7, 0, -1),
     np_int(slice(7, 0, -1)),
-    np_int(slice(7, 0, -1), np.int64),
+    np_int(slice(7, 0, -1), _np.int64),
     slice(None, 6),
     np_int(slice(None, 6)),
-    np_int(slice(None, 6), np.int64),
+    np_int(slice(None, 6), _np.int64),
     slice(None, 6, 3),
     np_int(slice(None, 6, 3)),
-    np_int(slice(None, 6, 3), np.int64),
+    np_int(slice(None, 6, 3), _np.int64),
     slice(1, None),
     np_int(slice(1, None)),
-    np_int(slice(1, None), np.int64),
+    np_int(slice(1, None), _np.int64),
     slice(1, None, 3),
     np_int(slice(1, None, 3)),
-    np_int(slice(1, None, 3), np.int64),
+    np_int(slice(1, None, 3), _np.int64),
     slice(None, None, 2),
     np_int(slice(None, None, 2)),
-    np_int(slice(None, None, 2), np.int64),
+    np_int(slice(None, None, 2), _np.int64),
     slice(None, None, -1),
     np_int(slice(None, None, -1)),
-    np_int(slice(None, None, -1), np.int64),
+    np_int(slice(None, None, -1), _np.int64),
     slice(None, None, -2),
-    np_int(slice(None, None, -2), np.int32),
-    np_int(slice(None, None, -2), np.int64),
+    np_int(slice(None, None, -2), _np.int32),
+    np_int(slice(None, None, -2), _np.int64),
     # Multiple ints as indices
     (1, 2, 3),
     np_int((1, 2, 3)),
-    np_int((1, 2, 3), np.int64),
+    np_int((1, 2, 3), _np.int64),
     (-1, -2, -3),
     np_int((-1, -2, -3)),
-    np_int((-1, -2, -3), np.int64),
+    np_int((-1, -2, -3), _np.int64),
     (1, 2, 3, 4),
     np_int((1, 2, 3, 4)),
-    np_int((1, 2, 3, 4), np.int64),
+    np_int((1, 2, 3, 4), _np.int64),
     (-4, -3, -2, -1),
     (-4, mx.np.array(-3, dtype='int32'), -2, -1),
     (-4, mx.np.array(-3, dtype='int64'), -2, -1),
     np_int((-4, -3, -2, -1)),
-    np_int((-4, -3, -2, -1), np.int64),
+    np_int((-4, -3, -2, -1), _np.int64),
     # slice(None) as indices
     (slice(None), slice(None), 1, 8),
     (slice(None), slice(None), np.array(1, dtype='int32'), 8),
@@ -873,26 +873,26 @@ def test_setitem_autograd(np_array, index):
     (slice(None), slice(None), 1, -8),
     (slice(None), slice(None), -1, -8),
     np_int((slice(None), slice(None), 1, 8)),
-    np_int((slice(None), slice(None), 1, 8), np.int64),
+    np_int((slice(None), slice(None), 1, 8), _np.int64),
     (slice(None), slice(None), 1, 8),
     np_int((slice(None), slice(None), -1, -8)),
-    np_int((slice(None), slice(None), -1, -8), np.int64),
+    np_int((slice(None), slice(None), -1, -8), _np.int64),
     (slice(None), 2, slice(1, 5), 1),
     np_int((slice(None), 2, slice(1, 5), 1)),
-    np_int((slice(None), 2, slice(1, 5), 1), np.int64),
+    np_int((slice(None), 2, slice(1, 5), 1), _np.int64),
     # Mixture of ints and slices as indices
     (slice(None, None, -1), 2, slice(1, 5), 1),
     np_int((slice(None, None, -1), 2, slice(1, 5), 1)),
-    np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64),
+    np_int((slice(None, None, -1), 2, slice(1, 5), 1), _np.int64),
     (slice(None, None, -1), 2, slice(1, 7, 2), 1),
     np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)),
-    np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64),
+    np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), _np.int64),
     (slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)),
     np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))),
-    np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64),
+    np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), _np.int64),
     (slice(1, 8, 2), 1, slice(3, 8), 2),
     np_int((slice(1, 8, 2), 1, slice(3, 8), 2)),
-    np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64),
+    np_int((slice(1, 8, 2), 1, slice(3, 8), 2), _np.int64),
     # Test Ellipsis ('...')
     (1, Ellipsis, -1),
     (slice(2), Ellipsis, None, 0),
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 0db209c5774f..3d4aeb2f6d57 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -11239,3 +11239,264 @@ def forward(self, x, *args):
     assert_almost_equal(deconvOut, deconvRefOut)
     assert_almost_equal(deconvData.grad, deconvRefGrad)
+
+
+@use_np
+@pytest.mark.parametrize('dtype', np.floating_dtypes)
+def test_np_finfo(dtype):
+    mx_finfo_obj = np.finfo(dtype)
+    np_finfo = onp.finfo(dtype)
+    assert (mx_finfo_obj.bits, mx_finfo_obj.eps, mx_finfo_obj.max,
+            mx_finfo_obj.min, mx_finfo_obj.smallest_normal) == \
+           (np_finfo.bits, np_finfo.eps, np_finfo.max, np_finfo.min, np_finfo.tiny)
+
+
+@use_np
+@pytest.mark.parametrize('dtype', np.integer_dtypes)
+def test_np_iinfo(dtype):
+    mx_iinfo_obj = np.iinfo(dtype)
+    np_iinfo = onp.iinfo(dtype)
+    assert (mx_iinfo_obj.bits, mx_iinfo_obj.max, mx_iinfo_obj.min) == \
+           (np_iinfo.bits, np_iinfo.max, np_iinfo.min)
+
+
+@use_np
+@pytest.mark.parametrize('input1', [d for d in np.numeric_dtypes + np.boolean_dtypes] +
+                                   [np.ones((1,), dtype=d) for d in np.numeric_dtypes + np.boolean_dtypes])
+@pytest.mark.parametrize('input2', [d for d in np.numeric_dtypes + np.boolean_dtypes])
+def test_np_can_cast(input1, input2):
+    np_input1 = input1
+    np_input2 = input2
+    if isinstance(input1, np.ndarray):
+        np_input1 = input1.asnumpy()
+    assert np.can_cast(input1, input2) == onp.can_cast(np_input1, np_input2)
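The three tests above check MXNet's dtype introspection against official NumPy. For orientation, what can_cast answers in plain NumPy:

    import numpy as onp

    assert onp.can_cast('int8', 'int32')            # safe widening
    assert not onp.can_cast('float64', 'float32')   # would lose precision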
+@use_np
+@pytest.mark.parametrize('nums', [1, 2, 3, 4, 10, 100])
+def test_np_result_type(nums):
+    import random
+    PICK_LIST = np.numeric_dtypes + np.boolean_dtypes + \
+        [np.ones((1,), dtype=d) for d in np.numeric_dtypes + np.boolean_dtypes]
+    inputs = [random.choice(PICK_LIST) for _ in range(nums)]
+    try:
+        np.result_type(*inputs)
+    except TypeError:
+        # combinations without a defined promotion must consistently raise TypeError
+        with pytest.raises(TypeError):
+            np.result_type(*inputs)
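test_np_result_type samples random dtype/array combinations, so undefined promotions are only required to surface as TypeError. A deterministic spot check of the same API (dtypes and arrays are both accepted):

    from mxnet import np

    assert np.result_type(np.int8, np.int64) == np.int64
    assert np.result_type(np.float16, np.ones((1,), dtype='float64')) == np.float64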
+
+
+@use_np
+@pytest.mark.parametrize('func,func2,dtypes,ref_grad,low,high', [
+    ('abs', 'abs', 'numeric', lambda x: -1. * (x < 0) + (x > 0), -1.0, 1.0),
+    ('acos', 'arccos', 'floating-point', lambda x: -1. / (1. - x ** 2.) ** (1. / 2.), -1.0, 1.0),
+    ('acosh', 'arccosh', 'floating-point', lambda x: 1./(x**2 - 1.)**(1./2.), 2.0, 5.0),
+    ('asin', 'arcsin', 'floating-point', lambda x: 1. / (1. - x ** 2) ** (1. / 2.), -1.0, 1.0),
+    ('asinh', 'arcsinh', 'floating-point', lambda x: 1./(x**2 + 1.)**(1./2.), -1.0, 1.0),
+    ('atan', 'arctan', 'floating-point', lambda x: 1. / (x ** 2. + 1.), -1.0, 1.0),
+    ('atanh', 'arctanh', 'floating-point', lambda x: -1./(x**2 - 1.), -0.99, 0.99),
+    ('bitwise_invert', 'invert', 'integer or boolean', None, -5, 5),
+    ('ceil', 'ceil', 'numeric', None, -10.0, 10.0),
+    ('cos', 'cos', 'floating-point', lambda x: -onp.sin(x), -1.0, 1.0),
+    ('cosh', 'cosh', 'floating-point', lambda x: onp.sinh(x), -1.0, 1.0),
+    ('exp', 'exp', 'floating-point', lambda x: onp.exp(x), -1.0, 1.0),
+    ('expm1', 'expm1', 'floating-point', lambda x: onp.exp(x), -1.0, 1.0),
+    ('floor', 'floor', 'numeric', None, -10.0, 10.0),
+    ('log', 'log', 'floating-point', lambda x: 1.0 / x, 0.1, 5.0),
+    ('log10', 'log10', 'floating-point', lambda x: 1.0 / (x * onp.log(10)), 0.1, 10.0),
+    ('log1p', 'log1p', 'floating-point', lambda x: 1.0 / (1.0 + x), -0.9, 5.0),
+    ('log2', 'log2', 'floating-point', lambda x: 1.0 / (x * onp.log(2)), 0.1, 2.0),
+    ('logical_not', 'logical_not', 'boolean', None, -1.0, 1.0),
+    ('negative', 'negative', 'numeric', lambda x: -1. * onp.ones(x.shape), -1.0, 1.0),
+    ('positive', 'positive', 'numeric', lambda x: onp.ones(x.shape), -1.0, 1.0),
+    ('sign', 'sign', 'numeric', None, -1.0, 1.0),
+    ('sin', 'sin', 'floating-point', lambda x: onp.cos(x), -1.0, 1.0),
+    ('sinh', 'sinh', 'floating-point', lambda x: onp.cosh(x), -1.0, 1.0),
+    ('sqrt', 'sqrt', 'floating-point', lambda x: 0.5 / onp.sqrt(x), 0.001, 10.0),
+    ('square', 'square', 'numeric', lambda x: 2.0 * x, -1.0, 1.0),
+    ('tan', 'tan', 'floating-point', lambda x: onp.tan(x) ** 2 + 1.0, -1.0, 1.0),
+    ('tanh', 'tanh', 'floating-point', lambda x: 1. - onp.tanh(x) ** 2, -1.0, 1.0),
+    ('trunc', 'trunc', 'numeric', None, -5.0, 5.0),
+])
+@pytest.mark.parametrize('ndim', [2, 3, 4])
+def test_np_standard_unary_funcs(func, func2, dtypes, ref_grad, low, high, ndim):
+    class TestStandardUnary(HybridBlock):
+        def __init__(self, func):
+            super(TestStandardUnary, self).__init__()
+            self._func = func
+
+        def forward(self, a):
+            return getattr(np, self._func)(a)
+
+    type_mapping = {
+        'floating-point': np.floating_dtypes,
+        'numeric': np.numeric_dtypes,
+        'integer or boolean': np.integer_dtypes + np.boolean_dtypes,
+        'boolean': np.boolean_dtypes,
+    }
+
+    def array_values(low, high, shape):
+        for d in np.integer_dtypes + np.boolean_dtypes + np.floating_dtypes:
+            yield onp.random.uniform(low, high, shape).astype(d), d
+
+    shapes = [rand_shape_nd(ndim, dim=3), (1, 0, 2)]
+    for shape in shapes:
+        for (np_test_data, dtype) in array_values(low, high, shape):
+            if dtype in type_mapping[dtypes]:
+                rtol = 1e-2 if dtype == np.float16 else 1e-3
+                atol = 1e-4 if dtype == np.float16 else 1e-5
+                # avoid divide-by-zero warnings for integer log inputs
+                if (func in ('log', 'log10', 'log2') and
+                        dtype in ('int8', 'uint8', 'int32', 'int64')):
+                    low = 1
+                if func2 == 'arctanh' and dtype == 'bool':
+                    continue
+                np_func = getattr(onp, func2)
+                mx_func = TestStandardUnary(func)
+                mx_test_data = np.array(np_test_data, dtype=dtype)
+                for hybridize in [True, False]:
+                    if hybridize:
+                        mx_func.hybridize()
+                    if ref_grad:
+                        mx_test_data.attach_grad()
+                    np_out = np_func(np_test_data)
+                    with mx.autograd.record():
+                        y = mx_func(mx_test_data)
+                    assert y.shape == np_out.shape
+                    assert_almost_equal(y.asnumpy(), np_out, rtol=1e-3, atol=atol)
+                    if np_out.dtype == np.bool_:
+                        assert y.dtype == np.bool_
+
+                    if ref_grad and dtype in ('float16', 'float32', 'float64'):
+                        y.backward()
+                        assert_almost_equal(mx_test_data.grad.asnumpy(), ref_grad(np_test_data),
+                                            rtol=1e-1, atol=1e-2, equal_nan=True)
+
+                np_func = getattr(onp, func2)
+                mx_out = getattr(mx.np, func)(mx_test_data)
+                assert mx_out.shape == np_out.shape
+                assert np.result_type(mx_out) == dtype
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=1e-5)
+
+                assertRaises(NotImplementedError, getattr(np, func), mx_test_data, where=False)
+                assertRaises(NotImplementedError, getattr(np, func), mx_test_data, subok=False)
+                assertRaises(NotImplementedError, getattr(np, func), mx_test_data, dtype=onp.int8)
+                assertRaises(TypeError, getattr(np, func), mx_test_data, dtype="abcdefg")
+                assertRaises(NotImplementedError, getattr(np, func), mx_test_data, casting='safe')
+                assertRaises(TypeError, getattr(np, func), mx_test_data, casting='mxnet')
+                assertRaises(NotImplementedError, getattr(np, func), mx_test_data, order='C')
+                assertRaises(NotImplementedError, getattr(np, func), mx_test_data, order='mxnet')
+
+
+@use_np
+@pytest.mark.flaky
+@pytest.mark.parametrize('func,func2,promoted,dtypes,ref_grad_a,ref_grad_b,low,high', [
+    ('add', 'add', True, 'numeric', lambda y, x1, x2: onp.ones(y.shape), None, -1.0, 1.0),
+    ('atan2', 'arctan2', True, 'floating-point', lambda y, x1, x2: x2 / (onp.square(x1) + onp.square(x2)),
+     lambda y, x1, x2: -x1 / (onp.square(x1) + onp.square(x2)), -1, 1),
+    ('bitwise_and', 'bitwise_and', True, 'integer or boolean', None, None, -100, 100),
+    ('bitwise_or', 'bitwise_or', True, 'integer or boolean', None, None, -100, 100),
+    ('bitwise_xor', 'bitwise_xor', True, 'integer or boolean', None, None, -100, 100),
+    ('divide', 'divide', True, 'floating-point', lambda y, x1, x2: onp.ones(y.shape) / x2,
+     lambda y, x1, x2: -x1 / (x2 * x2), 0.1, 1.0),
+    ('equal', 'equal', False, 'all', None, None, 0.0, 2.0),
+    ('floor_divide', 'floor_divide', True, 'numeric', lambda y, x1, x2: onp.zeros(y.shape),
+     lambda y, x1, x2: onp.zeros(y.shape), 2.0, 10.0),
+    ('greater', 'greater', False, 'numeric', None, None, 0.0, 2.0),
+    ('greater_equal', 'greater_equal', False, 'numeric', None, None, 0.0, 2.0),
+    ('less', 'less', False, 'numeric', None, None, 0.0, 2.0),
+    ('less_equal', 'less_equal', False, 'numeric', None, None, 0.0, 2.0),
+    ('logaddexp', 'logaddexp', True, 'floating-point', lambda y, x1, x2: onp.exp(x1) / (onp.exp(x1) + onp.exp(x2)),
+     lambda y, x1, x2: onp.exp(x2) / (onp.exp(x1) + onp.exp(x2)), -10, 10),
+    ('logical_and', 'logical_and', False, 'boolean', None, None, -100, 100),
+    ('logical_or', 'logical_or', False, 'boolean', None, None, -100, 100),
+    ('logical_xor', 'logical_xor', False, 'boolean', None, None, -100, 100),
+    ('multiply', 'multiply', True, 'numeric', lambda y, x1, x2: onp.broadcast_to(x2, y.shape),
+     lambda y, x1, x2: onp.broadcast_to(x1, y.shape), -1.0, 1.0),
+    ('not_equal', 'not_equal', False, 'all', None, None, 0.0, 2.0),
+    ('pow', 'power', True, 'floating-point', lambda y, x1, x2: onp.power(x1, x2 - 1.0) * x2,
+     lambda y, x1, x2: onp.power(x1, x2) * onp.log(x1), 1.0, 3.0),
+    ('subtract', 'subtract', True, 'numeric', lambda y, x1, x2: onp.ones(y.shape),
+     lambda y, x1, x2: -onp.ones(y.shape), -1.0, 1.0),
+])
+@pytest.mark.parametrize('lshape,rshape', [
+    ((3, 2), (3, 2)),
+    ((3, 2), (3, 1)),
+    ((3, 1), (3, 0)),
+    ((0, 2), (1, 2)),
+    ((2, 3, 4), (3, 1)),
+    ((2, 3), ()),
+    ((), (2, 3))
+])
+def test_np_standard_binary_funcs(func, func2, promoted, dtypes, ref_grad_a, ref_grad_b, low, high, lshape, rshape):
+    class TestStandardBinary(HybridBlock):
+        def __init__(self, func):
+            super(TestStandardBinary, self).__init__()
+            self._func = func
+
+        def forward(self, a, b):
+            return getattr(np, self._func)(a, b)
+
+    type_mapping = {
+        'floating-point': np.floating_dtypes,
+        'numeric': np.numeric_dtypes,
+        'integer or boolean': np.integer_dtypes + np.boolean_dtypes,
+        'boolean': np.boolean_dtypes,
+        'all': np.numeric_dtypes + np.boolean_dtypes,
+    }
+
+    def array_values(low, high, shape):
+        for d in np.integer_dtypes + np.boolean_dtypes + np.floating_dtypes:
+            yield onp.random.uniform(low, high, shape).astype(d), d
+
+    for (left_value, ltype) in array_values(low, high, lshape):
+        for (right_value, rtype) in array_values(low, high, rshape):
+            if ltype in type_mapping[dtypes] and rtype in type_mapping[dtypes]:
+                try:
+                    np.result_type(ltype, rtype)
+                except TypeError:
+                    # unknown type promotion between the two dtypes
+                    continue
+                rtol = 1e-2 if ltype == np.float16 or rtype == np.float16 else 1e-3
+                atol = 1e-4 if ltype == np.float16 or rtype == np.float16 else 1e-5
+                mx_left_value = np.array(left_value, dtype=ltype)
+                mx_right_value = np.array(right_value, dtype=rtype)
+                mx_func = TestStandardBinary(func)
+                np_func = getattr(onp, func2)
+                for hybridize in [True, False]:
+                    if hybridize:
+                        mx_func.hybridize()
+                    if ref_grad_a:
+                        mx_left_value.attach_grad()
+                        mx_right_value.attach_grad()
+                    np_out = np_func(left_value, right_value)
+                    with mx.autograd.record():
+                        y = mx_func(mx_left_value, mx_right_value)
+                    assert y.shape == np_out.shape
+                    assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=rtol, atol=atol,
+                                        use_broadcast=False, equal_nan=True)
+
+                    if ref_grad_a and ltype in np.floating_dtypes and rtype in np.floating_dtypes:
+                        y.backward()
+                        assert_almost_equal(mx_left_value.grad.asnumpy(),
+                                            collapse_sum_like(ref_grad_a(y.asnumpy(), left_value, right_value), mx_left_value.shape),
+                                            rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False)
+                        if ref_grad_b is None:
+                            assert_almost_equal(mx_right_value.grad.asnumpy(),
+                                                collapse_sum_like(ref_grad_a(y.asnumpy(), right_value, left_value), mx_right_value.shape),
+                                                rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False)
+                        else:
+                            assert_almost_equal(mx_right_value.grad.asnumpy(),
+                                                collapse_sum_like(ref_grad_b(y.asnumpy(), left_value, right_value), mx_right_value.shape),
+                                                rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False)
+
+                np_out = getattr(onp, func2)(left_value, right_value)
+                mx_out = getattr(np, func)(mx_left_value, mx_right_value)
+                assert mx_out.shape == np_out.shape
+                if promoted:
+                    assert np.result_type(ltype, rtype) == mx_out.dtype
+                else:
+                    assert mx_out.dtype == np.bool_
+                assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=rtol, atol=atol,
+                                    use_broadcast=False, equal_nan=True)