From f4c127d1b95cf10bc200bd95ad3000f748e247dd Mon Sep 17 00:00:00 2001 From: chenxiny Date: Wed, 25 Dec 2019 11:18:29 +0800 Subject: [PATCH] fix int8 add ut --- tests/python/quantization/test_quantization.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 527737e03cd7..8c19a9358976 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -328,12 +328,11 @@ def check_quantized_elemwise_add(data_shape, qtype): elemwise_add_int8_exe.arg_dict[qarg_names[4]][:] = data_low elemwise_add_int8_exe.arg_dict[qarg_names[5]][:] = data_high qoutput, min_range, max_range = elemwise_add_int8_exe.forward() - min_val = min_range.asnumpy().tolist()[0] - max_val = max_range.asnumpy().tolist()[0] - fp32_rslt = output.asnumpy() - int8_rslt = qoutput.asnumpy()*max_val/0x7fffffff - assert_almost_equal(int8_rslt, int8_rslt, atol = 1e-4) + int8_rslt = qoutput.astype(output.dtype)*max_range/0x7fffffff + diff = mx.nd.abs(output - int8_rslt) + cond = mx.nd.lesser(2, diff).sum().asscalar() + assert cond == 0 for qtype in ['int8', 'uint8']: check_quantized_elemwise_add((4, 6), qtype)