From 94aa421ab7fb7d70fd18cea22859b227908128b7 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Wed, 2 Mar 2022 11:16:44 +0100 Subject: [PATCH 1/2] [v1.x] Reduce after quantization memory usage --- python/mxnet/contrib/quantization.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 732938e3b3ad..47138d6147cc 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -996,14 +996,20 @@ def __exit__(self, exc_type, exc_value, traceback): save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()}) nd_save(param_name, save_dict) + for k,v in net.collect_params().items(): + v.grad_req = 'null' net.collect_params().load(param_name, cast_dtype=True, dtype_source='saved') net.collect_params().reset_ctx(ctx) + if quantized_dtype == 'auto': mx.nd.waitall() net.optimize_for(x=data_nd, backend="MKLDNNShiftedQuantization") tmp_file = os.path.join(tmpdirname, 'model') net.export(tmp_file) - net = SymbolBlock.imports(tmp_file + '-symbol.json', data_names, tmp_file + '-0000.params') + net = SymbolBlock.imports(tmp_file + '-symbol.json', data_names) + for k,v in net.collect_params().items(): + v.grad_req = 'null' + net.collect_params().load(tmp_file + '-0000.params', cast_dtype=True, dtype_source='saved') return net def quantize_net(network, quantized_dtype='auto', quantize_mode='full', From 5771578ab62c6f7c3cb2b81b491881d5ebe83585 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Thu, 3 Mar 2022 09:42:44 +0100 Subject: [PATCH 2/2] fix sanity --- python/mxnet/contrib/quantization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 47138d6147cc..609e0c2fcd15 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -996,7 +996,7 @@ def __exit__(self, exc_type, exc_value, traceback): save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()}) nd_save(param_name, save_dict) - for k,v in net.collect_params().items(): + for _, v in net.collect_params().items(): v.grad_req = 'null' net.collect_params().load(param_name, cast_dtype=True, dtype_source='saved') net.collect_params().reset_ctx(ctx) @@ -1007,7 +1007,7 @@ def __exit__(self, exc_type, exc_value, traceback): tmp_file = os.path.join(tmpdirname, 'model') net.export(tmp_file) net = SymbolBlock.imports(tmp_file + '-symbol.json', data_names) - for k,v in net.collect_params().items(): + for _, v in net.collect_params().items(): v.grad_req = 'null' net.collect_params().load(tmp_file + '-0000.params', cast_dtype=True, dtype_source='saved') return net