diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 3b9786408df9..d6b5595036ad 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -29,7 +29,7 @@ from mxnet.test_utils import assert_almost_equal from mxnet.test_utils import use_np from mxnet.test_utils import is_op_runnable -from common import assertRaises, with_seed, random_seed +from common import assertRaises, with_seed, random_seed, setup_module, teardown_module from mxnet.numpy_dispatch_protocol import with_array_function_protocol, with_array_ufunc_protocol from mxnet.numpy_dispatch_protocol import _NUMPY_ARRAY_FUNCTION_LIST, _NUMPY_ARRAY_UFUNC_LIST @@ -62,8 +62,15 @@ def add_workload(name, *args, **kwargs): @staticmethod def get_workloads(name): + if OpArgMngr._args == {}: + _prepare_workloads() return OpArgMngr._args.get(name, None) + @staticmethod + def randomize_workloads(): + # Force a new _prepare_workloads(), which will be based on new random numbers + OpArgMngr._args = {} + def _add_workload_all(): # check bad element in all positions @@ -516,8 +523,8 @@ def _add_workload_linalg_cholesky(): shapes = [(1, 1), (2, 2), (3, 3), (50, 50), (3, 10, 10)] dtypes = (np.float32, np.float64) - for shape, dtype in itertools.product(shapes, dtypes): - with random_seed(1): + with random_seed(1): + for shape, dtype in itertools.product(shapes, dtypes): a = _np.random.randn(*shape) t = list(range(len(shape))) @@ -3183,9 +3190,6 @@ def _prepare_workloads(): _add_workload_vander() -_prepare_workloads() - - def _get_numpy_op_output(onp_op, *args, **kwargs): onp_args = [arg.asnumpy() if isinstance(arg, np.ndarray) else arg for arg in args] onp_kwargs = {k: v.asnumpy() if isinstance(v, np.ndarray) else v for k, v in kwargs.items()} @@ -3197,7 +3201,7 @@ def _get_numpy_op_output(onp_op, *args, **kwargs): return onp_op(*onp_args, **onp_kwargs) -def _check_interoperability_helper(op_name, *args, **kwargs): +def _check_interoperability_helper(op_name, rel_tol, abs_tol, *args, **kwargs): strs = op_name.split('.') if len(strs) == 1: onp_op = getattr(_np, op_name) @@ -3213,11 +3217,11 @@ def _check_interoperability_helper(op_name, *args, **kwargs): assert type(out) == type(expected_out) for arr, expected_arr in zip(out, expected_out): if isinstance(arr, np.ndarray): - assert_almost_equal(arr.asnumpy(), expected_arr, rtol=1e-3, atol=1e-4, use_broadcast=False, equal_nan=True) + assert_almost_equal(arr.asnumpy(), expected_arr, rtol=rel_tol, atol=abs_tol, use_broadcast=False, equal_nan=True) else: _np.testing.assert_equal(arr, expected_arr) elif isinstance(out, np.ndarray): - assert_almost_equal(out.asnumpy(), expected_out, rtol=1e-3, atol=1e-4, use_broadcast=False, equal_nan=True) + assert_almost_equal(out.asnumpy(), expected_out, rtol=rel_tol, atol=abs_tol, use_broadcast=False, equal_nan=True) elif isinstance(out, _np.dtype): _np.testing.assert_equal(out, expected_out) else: @@ -3229,6 +3233,7 @@ def _check_interoperability_helper(op_name, *args, **kwargs): def check_interoperability(op_list): + OpArgMngr.randomize_workloads() for name in op_list: if name in _TVM_OPS and not is_op_runnable(): continue @@ -3240,13 +3245,17 @@ def check_interoperability(op_list): if name in ['full_like', 'zeros_like', 'ones_like'] and \ StrictVersion(platform.python_version()) < StrictVersion('3.0.0'): continue + default_tols = (1e-3, 1e-4) + tols = {'linalg.tensorinv': (1e-2, 5e-3), + 'linalg.solve': (1e-3, 5e-2)} + (rel_tol, abs_tol) = tols.get(name, default_tols) print('Dispatch test:', name) workloads = OpArgMngr.get_workloads(name) assert workloads is not None, 'Workloads for operator `{}` has not been ' \ 'added for checking interoperability with ' \ 'the official NumPy.'.format(name) for workload in workloads: - _check_interoperability_helper(name, *workload['args'], **workload['kwargs']) + _check_interoperability_helper(name, rel_tol, abs_tol, *workload['args'], **workload['kwargs']) @with_seed()