From 5d96a16c07c49ad89b9ccfdbc6a74f07dd71a79d Mon Sep 17 00:00:00 2001 From: chinakook Date: Wed, 2 Sep 2020 04:58:09 +0800 Subject: [PATCH] fix block.export (#17970) * fix block.export ```net.hybridize``` may optimize out some ops. These ops are alive in nn.Block(also nn.HybridBlock), but its names are not contained in symbol's ```arg_names``` list. So ignore these ops except that their name are end with 'running_mean' or 'running_var'. * Update block.py let user can save their extra param. * add allow_extra add allow_extra to let user decide whether to save extra parameters or not. * Update block.py add moving_mean and moving_var when export model with SymbolBlock * Update python/mxnet/gluon/block.py typo Co-authored-by: Sheng Zha * Update block.py * Update block.py * Update python/mxnet/gluon/block.py Co-authored-by: Leonard Lausen Co-authored-by: Sheng Zha Co-authored-by: Leonard Lausen --- python/mxnet/gluon/block.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 9772e2394486..41ef2cb15d89 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -1238,8 +1238,11 @@ def export(self, path, epoch=0, remove_amp_cast=True): if name in arg_names: arg_dict['arg:{}'.format(name)] = param._reduce() else: - assert name in aux_names - arg_dict['aux:{}'.format(name)] = param._reduce() + if name not in aux_names: + warnings.warn('Parameter "{name}" is not found in the graph. ' + .format(name=name), stacklevel=3) + else: + arg_dict['aux:%s'%name] = param._reduce() save_fn = _mx_npx.save if is_np_array() else ndarray.save save_fn('%s-%04d.params'%(path, epoch), arg_dict)