From 441a522908a9d6b20d1455ce019d8cd86ab7ab06 Mon Sep 17 00:00:00 2001 From: mozga-intel Date: Tue, 17 Aug 2021 09:05:54 +0200 Subject: [PATCH] Test tak, add additional axis --- tests/python/unittest/test_operator.py | 72 +++++++++++++------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 0e07c3782922..cbae11e5b452 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4188,47 +4188,47 @@ def grad_helper(grad_in, axis, idx): for _ in range(idx_ndim): idx_shape += (np.random.randint(low=1, high=5), ) - data = mx.sym.Variable('a') - idx = mx.sym.Variable('indices') - idx = mx.sym.BlockGrad(idx) - result = mx.sym.take(a=data, indices=idx, axis=axis, mode=mode) - exe = result._simple_bind(default_context(), a=data_shape, - indices=idx_shape) - data_real = np.random.normal(size=data_shape).astype('float32') - if out_of_range: - idx_real = np.random.randint(low=-data_shape[axis], high=data_shape[axis], size=idx_shape) - if mode == 'raise': - idx_real[idx_real == 0] = 1 - idx_real *= data_shape[axis] - else: - idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape) - if axis < 0: - axis += len(data_shape) + data = mx.sym.Variable('a') + idx = mx.sym.Variable('indices') + idx = mx.sym.BlockGrad(idx) + result = mx.sym.take(a=data, indices=idx, axis=axis, mode=mode) + exe = result._simple_bind(default_context(), a=data_shape, + indices=idx_shape) + data_real = np.random.normal(size=data_shape).astype('float32') + if out_of_range: + idx_real = np.random.randint(low=-data_shape[axis], high=data_shape[axis], size=idx_shape) + if mode == 'raise': + idx_real[idx_real == 0] = 1 + idx_real *= data_shape[axis] + else: + idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape) + if axis < 0: + axis += len(data_shape) - grad_out = np.ones((data_shape[0:axis] if axis > 0 else ()) + idx_shape + (data_shape[axis+1:] if axis < len(data_shape) - 1 else ()), dtype='float32') - grad_in = np.zeros(data_shape, dtype='float32') + grad_out = np.ones((data_shape[0:axis] if axis > 0 else ()) + idx_shape + (data_shape[axis+1:] if axis < len(data_shape) - 1 else ()), dtype='float32') + grad_in = np.zeros(data_shape, dtype='float32') - exe.arg_dict['a'][:] = mx.nd.array(data_real) - exe.arg_dict['indices'][:] = mx.nd.array(idx_real) - exe.forward(is_train=True) - if out_of_range and mode == 'raise': - try: - mx_out = exe.outputs[0].asnumpy() - except MXNetError as e: - return - else: - # Did not raise exception - assert False, "did not raise %s" % MXNetError.__name__ + exe.arg_dict['a'][:] = mx.nd.array(data_real) + exe.arg_dict['indices'][:] = mx.nd.array(idx_real) + exe.forward(is_train=True) + if out_of_range and mode == 'raise': + try: + mx_out = exe.outputs[0].asnumpy() + except MXNetError as e: + return + else: + # Did not raise exception + assert False, "did not raise %s" % MXNetError.__name__ - assert_almost_equal(exe.outputs[0], np.take(data_real, idx_real, axis=axis, mode=mode)) + assert_almost_equal(exe.outputs[0], np.take(data_real, idx_real, axis=axis, mode=mode)) - for i in np.nditer(idx_real): - if mode == 'clip': - i = np.clip(i, 0, data_shape[axis]) - grad_helper(grad_in, axis, i) + for i in np.nditer(idx_real): + if mode == 'clip': + i = np.clip(i, 0, data_shape[axis]) + grad_helper(grad_in, axis, i) - exe.backward([mx.nd.array(grad_out)]) - assert_almost_equal(exe.grad_dict['a'], grad_in) + exe.backward([mx.nd.array(grad_out)]) + assert_almost_equal(exe.grad_dict['a'], grad_in) def test_grid_generator():