From 381318444fbe606453b27bc544c8f04fd8df0e31 Mon Sep 17 00:00:00 2001 From: Piyush Ghai Date: Thu, 2 Aug 2018 11:25:42 -0700 Subject: [PATCH 1/2] Added tolerance level for assert_almost_equal for MBCC --- .../nightly/model_backwards_compatibility_check/common.py | 2 ++ .../model_backwards_compat_inference.py | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/nightly/model_backwards_compatibility_check/common.py b/tests/nightly/model_backwards_compatibility_check/common.py index 4c61cc4e3267..8950a9270839 100644 --- a/tests/nightly/model_backwards_compatibility_check/common.py +++ b/tests/nightly/model_backwards_compatibility_check/common.py @@ -41,6 +41,8 @@ backslash = '/' s3 = boto3.resource('s3') ctx = mx.cpu(0) +atol_default = 1e-5 +rtol_default = 1e-5 def get_model_path(model_name): diff --git a/tests/nightly/model_backwards_compatibility_check/model_backwards_compat_inference.py b/tests/nightly/model_backwards_compatibility_check/model_backwards_compat_inference.py index ae368e3a0fc6..5d63e7e9bca3 100644 --- a/tests/nightly/model_backwards_compatibility_check/model_backwards_compat_inference.py +++ b/tests/nightly/model_backwards_compatibility_check/model_backwards_compat_inference.py @@ -44,7 +44,7 @@ def test_module_checkpoint_api(): old_inference_results = load_inference_results(model_name) inference_results = loaded_model.predict(data_iter) # Check whether they are equal or not ? - assert_almost_equal(inference_results.asnumpy(), old_inference_results.asnumpy()) + assert_almost_equal(inference_results.asnumpy(), old_inference_results.asnumpy(), rtol=rtol_default, atol=atol_default) clean_model_files(model_files, model_name) logging.info('=================================') @@ -69,7 +69,7 @@ def test_lenet_gluon_load_params_api(): loaded_model.load_params(model_name + '-params') output = loaded_model(test_data) old_inference_results = mx.nd.load(model_name + '-inference')['inference'] - assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy()) + assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy(), rtol=rtol_default, atol=atol_default) clean_model_files(model_files, model_name) logging.info('=================================') logging.info('Assertion passed for model : %s' % model_name) @@ -92,7 +92,7 @@ def test_lenet_gluon_hybrid_imports_api(): loaded_model = gluon.SymbolBlock.imports(model_name + '-symbol.json', ['data'], model_name + '-0000.params') output = loaded_model(test_data) old_inference_results = mx.nd.load(model_name + '-inference')['inference'] - assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy()) + assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy(), rtol=rtol_default, atol=atol_default) clean_model_files(model_files, model_name) logging.info('=================================') logging.info('Assertion passed for model : %s' % model_name) @@ -124,7 +124,7 @@ def test_lstm_gluon_load_parameters_api(): loaded_model.load_parameters(model_name + '-params') output = loaded_model(test_data) old_inference_results = mx.nd.load(model_name + '-inference')['inference'] - assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy()) + assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy(), rtol=rtol_default, atol=atol_default) clean_model_files(model_files, model_name) logging.info('=================================') logging.info('Assertion passed for model : %s' % model_name) From bb271bb62268fd63fcae8fa6b1e8b33310ee7611 Mon Sep 17 00:00:00 2001 From: Piyush Ghai Date: Thu, 2 Aug 2018 14:27:00 -0700 Subject: [PATCH 2/2] Nudge to CI