diff --git a/nnvm/tests/python/frontend/onnx/test_forward.py b/nnvm/tests/python/frontend/onnx/test_forward.py index 82b5d319f92f..6ca78b5c0806 100644 --- a/nnvm/tests/python/frontend/onnx/test_forward.py +++ b/nnvm/tests/python/frontend/onnx/test_forward.py @@ -272,7 +272,7 @@ def test_slice(): _test_slice_iteration(x, x[:, 1:1000], (1), (1000), (1)) _test_slice_iteration(x, x[:, 0:-1], (0), (-1), (1)) -def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs): +def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs, rtol=1e-7, atol=1e-7): indata = np.random.uniform(-1, 1, size=inshape).astype(dtype) outdata = outfunc(indata, **npargs) @@ -290,7 +290,7 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs): for target, ctx in ctx_list(): tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype) - tvm.testing.assert_allclose(outdata, tvm_out) + tvm.testing.assert_allclose(outdata, tvm_out, rtol=rtol, atol=atol) def test_floor(): _test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, 'float32', 'Floor', {}) @@ -863,7 +863,7 @@ def test_binary_ops(): dtype = "float32" out_shape = in_shape - def verify_binary_ops(op, x, y, out_np, broadcast=None): + def verify_binary_ops(op, x, y, out_np, broadcast=None, rtol=1e-7, atol=1e-7): if broadcast is None: z = helper.make_node(op, ['in1', 'in2'], ['out']) else: @@ -879,7 +879,7 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None): model = helper.make_model(graph, producer_name='_test') for target, ctx in ctx_list(): tvm_out = get_tvm_output(model, [x, y], target, ctx) - tvm.testing.assert_allclose(out_np, tvm_out) + tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol) x = np.random.uniform(size=in_shape).astype(dtype) y = np.random.uniform(size=in_shape).astype(dtype) @@ -890,8 +890,8 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None): verify_binary_ops("Sub", x, z, x - z, broadcast=True) verify_binary_ops("Mul",x, y, x * y, broadcast=None) verify_binary_ops("Mul", x, z, x * z, broadcast=True) - verify_binary_ops("Div", x, y, x / y, broadcast=None) - verify_binary_ops("Div", x, z, x / z, broadcast=True) + verify_binary_ops("Div", x, y, x / y, broadcast=None, rtol=1e-5, atol=1e-5) + verify_binary_ops("Div", x, z, x / z, broadcast=True, rtol=1e-5, atol=1e-5) verify_binary_ops("Sum", x, y, x + y, broadcast=None) def test_single_ops(): @@ -899,7 +899,7 @@ def test_single_ops(): dtype = "float32" out_shape = in_shape - def verify_single_ops(op, x, out_np): + def verify_single_ops(op, x, out_np, rtol=1e-7, atol=1e-7): z = helper.make_node(op, ['in1'], ['out']) graph = helper.make_graph([z], '_test', @@ -915,8 +915,8 @@ def verify_single_ops(op, x, out_np): x = np.random.uniform(size=in_shape).astype(dtype) verify_single_ops("Neg",x, -x) verify_single_ops("Abs",x, np.abs(x)) - verify_single_ops("Reciprocal",x, 1/x) - verify_single_ops("Sqrt",x, np.sqrt(x)) + verify_single_ops("Reciprocal",x, 1/x, rtol=1e-5, atol=1e-5) + verify_single_ops("Sqrt",x, np.sqrt(x), rtol=1e-5, atol=1e-5) verify_single_ops("Relu",x, np.maximum(x, 0)) verify_single_ops("Exp",x, np.exp(x)) verify_single_ops("Log",x, np.log(x)) @@ -1004,7 +1004,9 @@ def test_LogSoftmax(): {}, 'float32', 'LogSoftmax', - {'axis': 1}) + {'axis': 1}, + rtol=1e-5, + atol=1e-5) if __name__ == '__main__': # verify_super_resolution_example()