diff --git a/tests/python/unittest/test_executor.py b/tests/python/unittest/test_executor.py index 630cad87496d..3117f6646481 100644 --- a/tests/python/unittest/test_executor.py +++ b/tests/python/unittest/test_executor.py @@ -18,13 +18,7 @@ import numpy as np import mxnet as mx from common import setup_module, with_seed, teardown - - -def reldiff(a, b): - diff = np.sum(np.abs(a - b)) - norm = np.sum(np.abs(a)) - reldiff = diff / norm - return reldiff +from mxnet.test_utils import assert_almost_equal def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None): @@ -64,9 +58,9 @@ def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None): out1 = uf(lhs_arr.asnumpy(), rhs_arr.asnumpy()) out3 = exec3.outputs[0].asnumpy() out4 = exec4.outputs[0].asnumpy() - assert reldiff(out1, out2) < 1e-6 - assert reldiff(out1, out3) < 1e-6 - assert reldiff(out1, out4) < 1e-6 + assert_almost_equal(out1, out2, rtol=1e-5, atol=1e-5) + assert_almost_equal(out1, out3, rtol=1e-5, atol=1e-5) + assert_almost_equal(out1, out4, rtol=1e-5, atol=1e-5) # test gradient out_grad = mx.nd.array(np.ones(out2.shape)) lhs_grad2, rhs_grad2 = gf(out_grad.asnumpy(), @@ -74,8 +68,8 @@ def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None): rhs_arr.asnumpy()) executor.backward([out_grad]) - assert reldiff(lhs_grad.asnumpy(), lhs_grad2) < 1e-6 - assert reldiff(rhs_grad.asnumpy(), rhs_grad2) < 1e-6 + assert_almost_equal(lhs_grad.asnumpy(), lhs_grad2, rtol=1e-5, atol=1e-5) + assert_almost_equal(rhs_grad.asnumpy(), rhs_grad2, rtol=1e-5, atol=1e-5) @with_seed(0) @@ -118,12 +112,14 @@ def check_bind(disable_bulk_exec): check_bind(False) -@with_seed(0) +# @roywei: Removing fixed seed as flakiness in this test is fixed +# tracked at https://github.com/apache/incubator-mxnet/issues/11686 +@with_seed() def test_dot(): nrepeat = 10 maxdim = 4 for repeat in range(nrepeat): - s =tuple(np.random.randint(1, 500, size=3)) + s =tuple(np.random.randint(1, 200, size=3)) check_bind_with_uniform(lambda x, y: np.dot(x, y), lambda g, x, y: (np.dot(g, y.T), np.dot(x.T, g)), 2, @@ -131,7 +127,7 @@ def test_dot(): rshape=(s[1], s[2]), sf = mx.symbol.dot) for repeat in range(nrepeat): - s =tuple(np.random.randint(1, 500, size=1)) + s =tuple(np.random.randint(1, 200, size=1)) check_bind_with_uniform(lambda x, y: np.dot(x, y), lambda g, x, y: (g * y, g * x), 2,