From 9a81a402bd322342e444e9def910f66e01a6cb0c Mon Sep 17 00:00:00 2001 From: chinakook Date: Wed, 2 Sep 2020 04:58:09 +0800 Subject: [PATCH] fix block.export (#17970) * fix block.export ```net.hybridize``` may optimize out some ops. These ops are alive in nn.Block(also nn.HybridBlock), but its names are not contained in symbol's ```arg_names``` list. So ignore these ops except that their name are end with 'running_mean' or 'running_var'. * Update block.py let user can save their extra param. * add allow_extra add allow_extra to let user decide whether to save extra parameters or not. * Update block.py add moving_mean and moving_var when export model with SymbolBlock * Update python/mxnet/gluon/block.py typo Co-authored-by: Sheng Zha * Update block.py * Update block.py * Update python/mxnet/gluon/block.py Co-authored-by: Leonard Lausen Co-authored-by: Sheng Zha Co-authored-by: Leonard Lausen --- python/mxnet/gluon/block.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index bed6679be2e6..d61dbaddbc7b 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -1195,12 +1195,16 @@ def export(self, path, epoch=0, remove_amp_cast=True): arg_names = set(sym.list_arguments()) aux_names = set(sym.list_auxiliary_states()) arg_dict = {} - for name, param in self.collect_params().items(): - if name in arg_names: - arg_dict['arg:%s'%name] = param._reduce() - else: - assert name in aux_names - arg_dict['aux:%s'%name] = param._reduce() + for is_arg, name, param in self._cached_op_args: + if not is_arg: + if name in arg_names: + arg_dict['arg:{}'.format(name)] = param._reduce() + else: + if name not in aux_names: + warnings.warn('Parameter "{name}" is not found in the graph. ' + .format(name=name), stacklevel=3) + else: + arg_dict['aux:%s'%name] = param._reduce() save_fn = _mx_npx.save if is_np_array() else ndarray.save save_fn('%s-%04d.params'%(path, epoch), arg_dict)