This repository was archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.7k
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
test_np_round #19081
Copy link
Open
Labels
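Description
`test_np_round` fails on the Windows CI build (win32, Python 3.7.3): for one float32 element, MXNet's `np.round` output (`a`, -0.306) differs from NumPy's reference `round` output (`b`, -0.307). The element-wise comparison therefore exceeds the test's tolerance (rtol=1e-3, atol=1e-5).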
[2020-09-02T00:02:22.521Z] ================================== FAILURES ===================================
[2020-09-02T00:02:22.521Z] ________________________________ test_np_round ________________________________
[2020-09-02T00:02:22.521Z] [gw3] win32 -- Python 3.7.3 C:\Python37\python.exe
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] @with_seed()
[2020-09-02T00:02:22.521Z] @use_np
[2020-09-02T00:02:22.521Z] def test_np_round():
[2020-09-02T00:02:22.521Z] class TestRound(HybridBlock):
[2020-09-02T00:02:22.521Z] def __init__(self, func, decimals):
[2020-09-02T00:02:22.521Z] super(TestRound, self).__init__()
[2020-09-02T00:02:22.521Z] self.func = func
[2020-09-02T00:02:22.521Z] self.decimals = decimals
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] def hybrid_forward(self, F, x):
[2020-09-02T00:02:22.521Z] return getattr(F.np, self.func)(x, self.decimals)
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] shapes = [(), (1, 2, 3), (1, 0)]
[2020-09-02T00:02:22.521Z] types = ['int32', 'int64', 'float32', 'float64']
[2020-09-02T00:02:22.521Z] funcs = ['round', 'round_']
[2020-09-02T00:02:22.521Z] for hybridize, oneType, func in itertools.product([True, False], types, funcs):
[2020-09-02T00:02:22.521Z] rtol, atol = 1e-3, 1e-5
[2020-09-02T00:02:22.521Z] for shape in shapes:
[2020-09-02T00:02:22.521Z] for d in range(-5, 6):
[2020-09-02T00:02:22.521Z] test_round = TestRound(func, d)
[2020-09-02T00:02:22.521Z] if hybridize:
[2020-09-02T00:02:22.521Z] test_round.hybridize()
[2020-09-02T00:02:22.521Z] x = rand_ndarray(shape, dtype=oneType).as_np_ndarray()
[2020-09-02T00:02:22.521Z] np_out = getattr(_np, func)(x.asnumpy(), d)
[2020-09-02T00:02:22.521Z] mx_out = test_round(x)
[2020-09-02T00:02:22.521Z] assert mx_out.shape == np_out.shape
[2020-09-02T00:02:22.521Z] > assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] tests\python\unittest\test_numpy_op.py:7691:
[2020-09-02T00:02:22.521Z] _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] a = array([[[ 1.029 , -0.371 , 1.11 ],
[2020-09-02T00:02:22.521Z] [ 1.4289999, -0.636 , -0.306 ]]], dtype=float32)
[2020-09-02T00:02:22.521Z] b = array([[[ 1.029, -0.371, 1.11 ],
[2020-09-02T00:02:22.521Z] [ 1.429, -0.636, -0.307]]], dtype=float32)
[2020-09-02T00:02:22.521Z] rtol = 0.001, atol = 1e-05, names = ('a', 'b'), equal_nan = False
[2020-09-02T00:02:22.521Z] use_broadcast = True, mismatches = (10, 10)
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=False,
[2020-09-02T00:02:22.521Z] use_broadcast=True, mismatches=(10, 10)):
[2020-09-02T00:02:22.521Z] """Test that two numpy arrays are almost equal. Raise exception message if not.
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] Parameters
[2020-09-02T00:02:22.521Z] ----------
[2020-09-02T00:02:22.521Z] a : np.ndarray or mx.nd.array
[2020-09-02T00:02:22.521Z] b : np.ndarray or mx.nd.array
[2020-09-02T00:02:22.521Z] rtol : None or float or dict of dtype -> float
[2020-09-02T00:02:22.521Z] The relative threshold. Default threshold will be used if set to ``None``.
[2020-09-02T00:02:22.521Z] atol : None or float or dict of dtype -> float
[2020-09-02T00:02:22.521Z] The absolute threshold. Default threshold will be used if set to ``None``.
[2020-09-02T00:02:22.521Z] names : tuple of names, optional
[2020-09-02T00:02:22.521Z] The names used in error message when an exception occurs
[2020-09-02T00:02:22.521Z] equal_nan : boolean, optional
[2020-09-02T00:02:22.521Z] The flag determining how to treat NAN values in comparison
[2020-09-02T00:02:22.521Z] mismatches : tuple of mismatches
[2020-09-02T00:02:22.521Z] Maximum number of mismatches to be printed (mismatches[0]) and determine (mismatches[1])
[2020-09-02T00:02:22.521Z] """
[2020-09-02T00:02:22.521Z] if not use_broadcast:
[2020-09-02T00:02:22.521Z] checkShapes(a, b)
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] rtol, atol = get_tols(a, b, rtol, atol)
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] if isinstance(a, mx.numpy.ndarray):
[2020-09-02T00:02:22.521Z] a = a.asnumpy()
[2020-09-02T00:02:22.521Z] if isinstance(b, mx.numpy.ndarray):
[2020-09-02T00:02:22.521Z] b = b.asnumpy()
[2020-09-02T00:02:22.521Z] use_np_allclose = isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
[2020-09-02T00:02:22.521Z] if not use_np_allclose:
[2020-09-02T00:02:22.521Z] if not (hasattr(a, 'ctx') and hasattr(b, 'ctx') and a.ctx == b.ctx and a.dtype == b.dtype):
[2020-09-02T00:02:22.521Z] use_np_allclose = True
[2020-09-02T00:02:22.521Z] if isinstance(a, mx.nd.NDArray):
[2020-09-02T00:02:22.521Z] a = a.asnumpy()
[2020-09-02T00:02:22.521Z] if isinstance(b, mx.nd.NDArray):
[2020-09-02T00:02:22.521Z] b = b.asnumpy()
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] if use_np_allclose:
[2020-09-02T00:02:22.521Z] if hasattr(a, 'dtype') and a.dtype == np.bool_ and hasattr(b, 'dtype') and b.dtype == np.bool_:
[2020-09-02T00:02:22.521Z] np.testing.assert_equal(a, b)
[2020-09-02T00:02:22.521Z] return
[2020-09-02T00:02:22.521Z] if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
[2020-09-02T00:02:22.521Z] return
[2020-09-02T00:02:22.521Z] else:
[2020-09-02T00:02:22.521Z] output = mx.nd.contrib.allclose(a, b, rtol, atol, equal_nan)
[2020-09-02T00:02:22.521Z] if output.asnumpy() == 1:
[2020-09-02T00:02:22.521Z] return
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] a = a.asnumpy()
[2020-09-02T00:02:22.521Z] b = b.asnumpy()
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] index, rel = _find_max_violation(a, b, rtol, atol)
[2020-09-02T00:02:22.521Z] if index != ():
[2020-09-02T00:02:22.521Z] # a, b are the numpy arrays
[2020-09-02T00:02:22.521Z] indexErr = index
[2020-09-02T00:02:22.521Z] relErr = rel
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] print('\n*** Maximum errors for vector of size {}: rtol={}, atol={}\n'.format(a.size, rtol, atol))
[2020-09-02T00:02:22.521Z] aTmp = a.copy()
[2020-09-02T00:02:22.521Z] bTmp = b.copy()
[2020-09-02T00:02:22.521Z] i = 1
[2020-09-02T00:02:22.521Z] while i <= a.size:
[2020-09-02T00:02:22.521Z] if i <= mismatches[0]:
[2020-09-02T00:02:22.521Z] print("%3d: Error %f %s" %(i, rel, locationError(a, b, index, names)))
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] aTmp[index] = bTmp[index] = 0
[2020-09-02T00:02:22.521Z] if almost_equal(aTmp, bTmp, rtol, atol, equal_nan=equal_nan):
[2020-09-02T00:02:22.521Z] break
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] i += 1
[2020-09-02T00:02:22.521Z] if i <= mismatches[1] or mismatches[1] <= 0:
[2020-09-02T00:02:22.521Z] index, rel = _find_max_violation(aTmp, bTmp, rtol, atol)
[2020-09-02T00:02:22.521Z] else:
[2020-09-02T00:02:22.521Z] break
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] mismatchDegree = "at least " if mismatches[1] > 0 and i > mismatches[1] else ""
[2020-09-02T00:02:22.521Z] errMsg = "Error %f exceeds tolerance rtol=%e, atol=%e (mismatch %s%f%%).\n%s" % \
[2020-09-02T00:02:22.521Z] (relErr, rtol, atol, mismatchDegree, 100*i/a.size, \
[2020-09-02T00:02:22.521Z] locationError(a, b, indexErr, names, maxError=True))
[2020-09-02T00:02:22.521Z] else:
[2020-09-02T00:02:22.521Z] errMsg = "Error %f exceeds tolerance rtol=%e, atol=%e.\n" % (rel, rtol, atol)
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] np.set_printoptions(threshold=4, suppress=True)
[2020-09-02T00:02:22.521Z] msg = npt.build_err_msg([a, b], err_msg=errMsg)
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] > raise AssertionError(msg)
[2020-09-02T00:02:22.521Z] E AssertionError:
[2020-09-02T00:02:22.521Z] E Items are not equal:
[2020-09-02T00:02:22.521Z] E Error 3.154627 exceeds tolerance rtol=1.000000e-03, atol=1.000000e-05 (mismatch 16.666667%).
[2020-09-02T00:02:22.521Z] E Location of maximum error: (0, 1, 2), a=-0.30599999, b=-0.30700001
[2020-09-02T00:02:22.521Z] E ACTUAL: array([[[ 1.029 , -0.371 , 1.11 ],
[2020-09-02T00:02:22.521Z] E [ 1.4289999, -0.636 , -0.306 ]]], dtype=float32)
[2020-09-02T00:02:22.521Z] E DESIRED: array([[[ 1.029, -0.371, 1.11 ],
[2020-09-02T00:02:22.521Z] E [ 1.429, -0.636, -0.307]]], dtype=float32)
[2020-09-02T00:02:22.521Z]
[2020-09-02T00:02:22.521Z] windows_package\python\mxnet\test_utils.py:735: AssertionError