diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index b68765eedff8..977e3e07c96d 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -272,6 +272,17 @@ MXNET_DLL int MXRandomSeed(int seed); */ MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id); +/*! + * \brief Change floating-point calculations when dealing with denormalized values. + * Currently this option is only supported in the CPU backend. + * Flushing denormalized values to zero is enabled by default. + * + * \param value state of flush-to-zero and denormals-are-zero to set. + * \param prev_state output (must not be null): flush-to-zero state before setting new state. + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXSetFlushDenorms(bool value, bool* prev_state); + /*! * \brief Notify the engine about a shutdown, * This can help engine to print less messages into display. diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 1f9f37d04d88..2e8d4b484318 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -311,6 +311,8 @@ def _load_lib(): # library instance of mxnet _LIB = _load_lib() +check_call(_LIB.MXSetFlushDenorms(ctypes.c_bool(True), + ctypes.byref(ctypes.c_bool()))) # type definitions mx_int = ctypes.c_int mx_uint = ctypes.c_uint diff --git a/python/mxnet/util.py b/python/mxnet/util.py index cafff0f9dd9e..ea75030614be 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -1200,3 +1200,27 @@ def get_rtc_compile_opts(ctx): arch_opt = "--gpu-architecture={}_{}".format("sm" if should_compile_to_SASS else "compute", device_cc_as_used) return [arch_opt] + +def set_flush_denorms(value): + """Change floating-point calculations on CPU when dealing with denormalized values. + This is only applicable to architectures which support flush-to-zero. + Denormalized values are positive and negative values that are very close to 0 + (exponent is the smallest possible value). 
+ Flushing denormalized values to 0 can speed up calculations if such values occur, + but if full compliance with the IEEE 754 standard is required this option should be disabled. + Flushing denormalized values is enabled in MXNet by default. + + Parameters + ---------- + value : bool + State of flush-to-zero and denormals-are-zero to set in the MXCSR register + + Returns + ------- + prev_state : bool + Previous state of flush-to-zero in the MXCSR register + """ + ret = ctypes.c_bool() + passed_value = ctypes.c_bool(value) + check_call(_LIB.MXSetFlushDenorms(passed_value, ctypes.byref(ret))) + return ret.value diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index eac1944016df..c54cc0e6f470 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -62,6 +62,23 @@ #include "miniz.h" #include "nnvm/pass_functions.h" +// FTZ only applies to SSE and AVX instructions. +#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \ + (defined(_M_IX86_FP) && _M_IX86_FP >= 1) +#define SUPPORT_FTZ_DMZ 1 +#else +#define SUPPORT_FTZ_DMZ 0 +#endif + +#if SUPPORT_FTZ_DMZ +#include +#include +#endif +#if SUPPORT_FTZ_DMZ && !defined(_MSC_VER) +#include +#endif + + using namespace mxnet; // Internal function to get the information @@ -1587,6 +1604,52 @@ int MXRandomSeedContext(int seed, int dev_type, int dev_id) { API_END(); } +int MXSetFlushDenorms(bool value, bool* prev_state) { + API_BEGIN(); + *prev_state = false; + + #if SUPPORT_FTZ_DMZ + std::function is_dmz_flag_available = []() { + // Intel 64 and IA-32 Architectures Software Developer’s Manual: Vol. 
1 + // "Checking for the DAZ Flag in the MXCSR Register" + constexpr unsigned int mxcsr_mask_offset = 28; + constexpr unsigned int dmz_flag_offset = 5; + constexpr unsigned int fxsave_req_bytes = 512; + + char* fxsave_area_ptr = reinterpret_cast(malloc(fxsave_req_bytes)); + memset(fxsave_area_ptr, 0, fxsave_req_bytes); // zero-fill; NOTE(review): malloc is unchecked + _fxsave(fxsave_area_ptr); // NOTE(review): assumes 16-byte aligned buffer - confirm + + char* mxcsr_mask_ptr = fxsave_area_ptr + mxcsr_mask_offset; + uint32_t mxcsr_mask = *(reinterpret_cast((mxcsr_mask_ptr))); + // DAZ flag is supported if the sixth bit (bit 5) of MXCSR_MASK is set + bool dmz_flag = (mxcsr_mask >> dmz_flag_offset) & 0x1; + free(fxsave_area_ptr); + return dmz_flag; + }; + + Engine::Get()->PushSync( + [value, prev_state, is_dmz_flag_available](RunContext rctx) { + const unsigned int DMZ_STATE = value ? _MM_DENORMALS_ZERO_ON : _MM_DENORMALS_ZERO_OFF; + const unsigned int FTZ_STATE = value ? _MM_FLUSH_ZERO_ON : _MM_FLUSH_ZERO_OFF; + *prev_state = _MM_GET_FLUSH_ZERO_MODE(); + _MM_SET_FLUSH_ZERO_MODE(FTZ_STATE); + + // If the DAZ flag is not supported, then it is a reserved bit and attempting to write a 1 + // to it will cause a general-protection exception (#GP) + if (is_dmz_flag_available()) { + _MM_SET_DENORMALS_ZERO_MODE(DMZ_STATE); + } + }, Context::CPU(), {}, {}, + FnProperty::kNormal, 0, "SetFlushDenorms"); + + Engine::Get()->WaitForAll(); + + #endif + + API_END(); +} + int MXNotifyShutdown() { API_BEGIN(); mxnet::op::custom::CustomOperator::Get()->Stop(); diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 1fa7d5284399..1b8fe4d132c9 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -25,7 +25,7 @@ import numpy as _np import unittest import pytest -from mxnet import np +from mxnet import np, util from mxnet.test_utils import assert_almost_equal from mxnet.test_utils import use_np from mxnet.test_utils import is_op_runnable @@ 
-3341,7 +3341,11 @@ def test_np_array_function_protocol(): @with_array_ufunc_protocol @pytest.mark.serial def test_np_array_ufunc_protocol(): - check_interoperability(_NUMPY_ARRAY_UFUNC_LIST) + prev_state = util.set_flush_denorms(False) + try: + check_interoperability(_NUMPY_ARRAY_UFUNC_LIST) + finally: + util.set_flush_denorms(prev_state) @use_np