Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from mxnet.test_utils import assert_almost_equal
from mxnet.test_utils import use_np
from mxnet.test_utils import is_op_runnable
from common import assertRaises, with_seed, random_seed
from common import assertRaises, with_seed, random_seed, setup_module, teardown_module
from mxnet.numpy_dispatch_protocol import with_array_function_protocol, with_array_ufunc_protocol
from mxnet.numpy_dispatch_protocol import _NUMPY_ARRAY_FUNCTION_LIST, _NUMPY_ARRAY_UFUNC_LIST

Expand Down Expand Up @@ -62,8 +62,15 @@ def add_workload(name, *args, **kwargs):

@staticmethod
def get_workloads(name):
if OpArgMngr._args == {}:
_prepare_workloads()
return OpArgMngr._args.get(name, None)

@staticmethod
def randomize_workloads():
# Force a new _prepare_workloads(), which will be based on new random numbers
OpArgMngr._args = {}


def _add_workload_all():
# check bad element in all positions
Expand Down Expand Up @@ -516,8 +523,8 @@ def _add_workload_linalg_cholesky():
shapes = [(1, 1), (2, 2), (3, 3), (50, 50), (3, 10, 10)]
dtypes = (np.float32, np.float64)

for shape, dtype in itertools.product(shapes, dtypes):
with random_seed(1):
with random_seed(1):
for shape, dtype in itertools.product(shapes, dtypes):
a = _np.random.randn(*shape)

t = list(range(len(shape)))
Expand Down Expand Up @@ -3183,9 +3190,6 @@ def _prepare_workloads():
_add_workload_vander()


_prepare_workloads()


def _get_numpy_op_output(onp_op, *args, **kwargs):
onp_args = [arg.asnumpy() if isinstance(arg, np.ndarray) else arg for arg in args]
onp_kwargs = {k: v.asnumpy() if isinstance(v, np.ndarray) else v for k, v in kwargs.items()}
Expand All @@ -3197,7 +3201,7 @@ def _get_numpy_op_output(onp_op, *args, **kwargs):
return onp_op(*onp_args, **onp_kwargs)


def _check_interoperability_helper(op_name, *args, **kwargs):
def _check_interoperability_helper(op_name, rel_tol, abs_tol, *args, **kwargs):
strs = op_name.split('.')
if len(strs) == 1:
onp_op = getattr(_np, op_name)
Expand All @@ -3213,11 +3217,11 @@ def _check_interoperability_helper(op_name, *args, **kwargs):
assert type(out) == type(expected_out)
for arr, expected_arr in zip(out, expected_out):
if isinstance(arr, np.ndarray):
assert_almost_equal(arr.asnumpy(), expected_arr, rtol=1e-3, atol=1e-4, use_broadcast=False, equal_nan=True)
assert_almost_equal(arr.asnumpy(), expected_arr, rtol=rel_tol, atol=abs_tol, use_broadcast=False, equal_nan=True)
else:
_np.testing.assert_equal(arr, expected_arr)
elif isinstance(out, np.ndarray):
assert_almost_equal(out.asnumpy(), expected_out, rtol=1e-3, atol=1e-4, use_broadcast=False, equal_nan=True)
assert_almost_equal(out.asnumpy(), expected_out, rtol=rel_tol, atol=abs_tol, use_broadcast=False, equal_nan=True)
elif isinstance(out, _np.dtype):
_np.testing.assert_equal(out, expected_out)
else:
Expand All @@ -3229,6 +3233,7 @@ def _check_interoperability_helper(op_name, *args, **kwargs):


def check_interoperability(op_list):
OpArgMngr.randomize_workloads()
for name in op_list:
if name in _TVM_OPS and not is_op_runnable():
continue
Expand All @@ -3240,13 +3245,17 @@ def check_interoperability(op_list):
if name in ['full_like', 'zeros_like', 'ones_like'] and \
StrictVersion(platform.python_version()) < StrictVersion('3.0.0'):
continue
default_tols = (1e-3, 1e-4)
tols = {'linalg.tensorinv': (1e-2, 5e-3),
'linalg.solve': (1e-3, 5e-2)}
(rel_tol, abs_tol) = tols.get(name, default_tols)
print('Dispatch test:', name)
workloads = OpArgMngr.get_workloads(name)
assert workloads is not None, 'Workloads for operator `{}` has not been ' \
'added for checking interoperability with ' \
'the official NumPy.'.format(name)
for workload in workloads:
_check_interoperability_helper(name, *workload['args'], **workload['kwargs'])
_check_interoperability_helper(name, rel_tol, abs_tol, *workload['args'], **workload['kwargs'])


@with_seed()
Expand Down