diff --git a/tests/nightly/model_backwards_compatibility_check/common.py b/tests/nightly/model_backwards_compatibility_check/common.py index 4c61cc4e3267..8950a9270839 100644 --- a/tests/nightly/model_backwards_compatibility_check/common.py +++ b/tests/nightly/model_backwards_compatibility_check/common.py @@ -41,6 +41,8 @@ backslash = '/' s3 = boto3.resource('s3') ctx = mx.cpu(0) +atol_default = 1e-5 +rtol_default = 1e-5 def get_model_path(model_name): diff --git a/tests/nightly/model_backwards_compatibility_check/model_backwards_compat_inference.py b/tests/nightly/model_backwards_compatibility_check/model_backwards_compat_inference.py index ae368e3a0fc6..5d63e7e9bca3 100644 --- a/tests/nightly/model_backwards_compatibility_check/model_backwards_compat_inference.py +++ b/tests/nightly/model_backwards_compatibility_check/model_backwards_compat_inference.py @@ -44,7 +44,7 @@ def test_module_checkpoint_api(): old_inference_results = load_inference_results(model_name) inference_results = loaded_model.predict(data_iter) # Check whether they are equal or not ? - assert_almost_equal(inference_results.asnumpy(), old_inference_results.asnumpy()) + assert_almost_equal(inference_results.asnumpy(), old_inference_results.asnumpy(), rtol=rtol_default, atol=atol_default) clean_model_files(model_files, model_name) logging.info('=================================') @@ -69,7 +69,7 @@ def test_lenet_gluon_load_params_api(): loaded_model.load_params(model_name + '-params') output = loaded_model(test_data) old_inference_results = mx.nd.load(model_name + '-inference')['inference'] - assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy()) + assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy(), rtol=rtol_default, atol=atol_default) clean_model_files(model_files, model_name) logging.info('=================================') logging.info('Assertion passed for model : %s' % model_name) @@ -92,7 +92,7 @@ def test_lenet_gluon_hybrid_imports_api(): loaded_model = gluon.SymbolBlock.imports(model_name + '-symbol.json', ['data'], model_name + '-0000.params') output = loaded_model(test_data) old_inference_results = mx.nd.load(model_name + '-inference')['inference'] - assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy()) + assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy(), rtol=rtol_default, atol=atol_default) clean_model_files(model_files, model_name) logging.info('=================================') logging.info('Assertion passed for model : %s' % model_name) @@ -124,7 +124,7 @@ def test_lstm_gluon_load_parameters_api(): loaded_model.load_parameters(model_name + '-params') output = loaded_model(test_data) old_inference_results = mx.nd.load(model_name + '-inference')['inference'] - assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy()) + assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy(), rtol=rtol_default, atol=atol_default) clean_model_files(model_files, model_name) logging.info('=================================') logging.info('Assertion passed for model : %s' % model_name)