diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 527737e03cd7..8c19a9358976 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -328,12 +328,11 @@ def check_quantized_elemwise_add(data_shape, qtype): elemwise_add_int8_exe.arg_dict[qarg_names[4]][:] = data_low elemwise_add_int8_exe.arg_dict[qarg_names[5]][:] = data_high qoutput, min_range, max_range = elemwise_add_int8_exe.forward() - min_val = min_range.asnumpy().tolist()[0] - max_val = max_range.asnumpy().tolist()[0] - fp32_rslt = output.asnumpy() - int8_rslt = qoutput.asnumpy()*max_val/0x7fffffff - assert_almost_equal(int8_rslt, int8_rslt, atol = 1e-4) + int8_rslt = qoutput.astype(output.dtype)*max_range/0x7fffffff + diff = mx.nd.abs(output - int8_rslt) + cond = mx.nd.lesser(2, diff).sum().asscalar() + assert cond == 0 for qtype in ['int8', 'uint8']: check_quantized_elemwise_add((4, 6), qtype)