From cb9c291e921eb753780275a7753787e3e69d8f68 Mon Sep 17 00:00:00 2001 From: Rajeshii Date: Sun, 3 Feb 2019 03:19:37 +0000 Subject: [PATCH 1/2] exclude concat for gpu quantization --- example/quantization/imagenet_gen_qsym.py | 10 ++++++++++ tests/python/quantization/test_quantization.py | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/example/quantization/imagenet_gen_qsym.py b/example/quantization/imagenet_gen_qsym.py index 8a2818c4bca0..41713c3c3f51 100644 --- a/example/quantization/imagenet_gen_qsym.py +++ b/example/quantization/imagenet_gen_qsym.py @@ -155,6 +155,16 @@ def save_params(fname, arg_params, aux_params, logger=None): if args.ctx == 'gpu': calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1 or name.find('fc') != -1) + excluded_sym_names += ['ch_concat_3a_chconcat', + 'ch_concat_3b_chconcat', + 'ch_concat_3c_chconcat', + 'ch_concat_4a_chconcat', + 'ch_concat_4b_chconcat', + 'ch_concat_4c_chconcat', + 'ch_concat_4d_chconcat', + 'ch_concat_4e_chconcat', + 'ch_concat_5a_chconcat', + 'ch_concat_5b_chconcat'] else: calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1) excluded_sym_names += ['flatten', 'fc1'] diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 3ff4b69302fb..178505ee0aa8 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -593,7 +593,8 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape): excluded_names = [] if mx.current_context() == mx.cpu(): excluded_names += ['fc'] - excluded_names += ['concat'] + else: + excluded_names += ['concat'] optional_names = ['pool0'] for skip_optional_names in [False, True]: From 088c5303aed5df268e7942ea80a49bea1e7f977f Mon Sep 17 00:00:00 2001 From: Rajeshii Date: Wed, 6 Feb 2019 08:07:55 +0000 Subject: [PATCH 2/2] remove quantized_concat test in non-subgraph flow --- tests/python/quantization/test_quantization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 178505ee0aa8..3ff4b69302fb 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -593,8 +593,7 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape): excluded_names = [] if mx.current_context() == mx.cpu(): excluded_names += ['fc'] - else: - excluded_names += ['concat'] + excluded_names += ['concat'] optional_names = ['pool0'] for skip_optional_names in [False, True]: