From 14dc860f2a9a62a0ca8843d4d0b06c14b5248e85 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Thu, 19 Nov 2020 07:38:51 +0000 Subject: [PATCH 01/10] initial commit --- python/mxnet/gluon/block.py | 145 ++++++++++++++++++++++- tests/python/unittest/test_gluon_save.py | 59 +++++++++ 2 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 tests/python/unittest/test_gluon_save.py diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 4ab72902162c..5241cb90ca91 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -661,6 +661,149 @@ def hybridize(self, active=True, **kwargs): for cld in self._children.values(): cld.hybridize(active, **kwargs) + def save(self, prefix): + """Save the model architecture and parameters to load again later + + Saves the model architecture as a nested dictionary where each Block + in the model is a dictionary and its children are sub-dictionaries. + + Each Block is uniquely identified by Block class name and a unique ID. + We save the child's name that that parent uses for it to restore later + in order to match the saved parameters. + + Recursively traverses a Block's children in order (since its an + OrderedDict) and uses the unique ID to denote that specific Block. + Assumes that the model is created in an identical order every time. + If the model is not able to be recreated deterministically do not + use this set of APIs to save/load your model. + + For HybridBlocks, the cached_graph (Symbol & inputs) if + it has already been hybridized. + + Parameters + ---------- + prefix : str + The prefix to use in filenames for saving this model: + -model.json and -model.params + """ + # create empty model structure + model = {} + def _save_cached_graphs(blk, index, structure): + # create new entry for this block + mdl = {'orig_name': blk.name} + # encode unique name based on block type and ID + name = type(blk).__name__.lower() + structure[name+str(index[0])] = mdl + if isinstance(blk, mx.gluon.nn.HybridBlock): + if blk._cached_graph: + # save in/out formats + mdl['in_format'] = blk._in_format + mdl['out_format'] = blk._out_format + # save cached graph & input symbols + syms, out = blk._cached_graph + mdl_syms = [] + for sym in syms: + mdl_syms.append(sym.tojson()) + mdl['inputs'] = mdl_syms + mdl['symbol'] = out.tojson() + mdl['hybridized'] = True + else: + mdl['hybridized'] = False + children = dict() + mdl['children'] = children + # recursively save children + for ch_name, child in blk._children.items(): + index[0] += 1 + # save child's original name in this block's map + children[child.name] = ch_name + _save_cached_graphs(child, index, mdl) + # save top-level block + index = [0] + _save_cached_graphs(self, index, model) + # save model + fp = open(prefix+'-model.json','w') + json.dump(model, fp) + fp.close() + # save params + self.save_parameters(prefix+'-model.params') + + def load(self, prefix): + """Load a model saved using the `save` API + + Reconfigures a model using the saved configuration. This function + does not regenerate the model architecture. It resets the children's + names as they were when saved in order to match the names of the + saved parameters. + + This function assumes the Blocks in the model were created in the same + order they were when the model was saved. This is because each Block is + uniquely identified by Block class name and a unique ID in order (since + its an OrderedDict) and uses the unique ID to denote that specific Block. + Assumes that the model is created in an identical order every time. + If the model is not able to be recreated deterministically do not + use this set of APIs to save/load your model. + + For HybridBlocks, the cached_graph (Symbol & inputs) and settings are + restored if it had been hybridized before saving. + + Parameters + ---------- + prefix : str + The prefix to use in filenames for loading this model: + -model.json and -model.params + """ + # load model json from file + fp = open(prefix+'-model.json') + model = json.load(fp) + fp.close() + def _load_cached_graphs(blk, index, structure): + # get block name + name = type(blk).__name__.lower() + # lookup previous encoded name based on block type and ID + mdl = structure[name+str(index[0])] + # rename block to what it was when saved + blk._name = mdl['orig_name'] + if isinstance(blk, mx.gluon.nn.HybridBlock): + if mdl['hybridized']: + # restore in/out formats + blk._in_format = mdl['in_format'] + blk._out_format = mdl['out_format'] + # get saved symbol + out = mx.sym.load_json(mdl['symbol']) + syms = [] + # recreate inputs for this symbol + for inp in mdl['inputs']: + syms.append(mx.sym.load_json(inp)) + # reset cached_graph and active status + blk._cached_graph = (syms, out) + blk._active = True + # rename params with updated block name + pnames = list(blk.params.keys()) + for p in pnames: + param = blk.params._params[p] + new_name = blk.name +'_'+ p[len(blk.params._prefix):] + blk.params._params.pop(p) + blk.params._params[new_name] = param + # recursively reload children + for ch_name, child in blk._children.items(): + index[0] += 1 + _load_cached_graphs(child, index, mdl) + # current set of child names + ch_names = list(blk._children.keys()) + # original child names + children = mdl['children'] + # loop and remap children with original names + for ch_name in ch_names: + child = blk._children[ch_name] + blk._children.pop(ch_name) + orig_name = children[child.name] + blk._children[orig_name] = child + # load top-level block + index = [0] + _load_cached_graphs(self, index, model) + # load params + self.load_parameters(prefix+'-model.params') + def cast(self, dtype): """Cast this Block to use another data type. @@ -1259,7 +1402,7 @@ def infer_shape(self, *args): def infer_type(self, *args): """Infers data type of Parameters from inputs.""" self._infer_attrs('infer_type', 'dtype', *args) - + def export(self, path, epoch=0, remove_amp_cast=True): """Export HybridBlock to json format that can be loaded by `gluon.SymbolBlock.imports`, `mxnet.mod.Module` or the C++ interface. diff --git a/tests/python/unittest/test_gluon_save.py b/tests/python/unittest/test_gluon_save.py new file mode 100644 index 000000000000..dac70968f851 --- /dev/null +++ b/tests/python/unittest/test_gluon_save.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from common import with_seed + +@with_seed() +def test_save(): + class MyBlock(mx.gluon.nn.Block): + def __init__(self, **kwargs): + super(MyBlock, self).__init__(**kwargs) + def add(self, block): + self._children[block.name + str(len(self._children))] = block + def forward(self, x, *args): + out = (x,) + args + for block in self._children.values(): + out = block(*out) + return out + + def createNet(): + inside = MyBlock() + dense = mx.gluon.nn.Dense(10) + inside.add(dense) + net = MyBlock() + net.add(inside) + net.add(mx.gluon.nn.Dense(10)) + return net + + # create and initialize model + net1 = createNet() + net1.initialize() + # hybridize (the hybridizeable blocks, ie. the Dense layers) + net1.hybridize() + x = mx.nd.empty((1,10)) + out1 = net1(x) + + # save hybridized model + net1.save('MyModel') + + # create a new model, uninitialized + net2 = createNet() + # reload hybridized model + net2.load('MyModel') + # run inference again + out2 = net2(x) + mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy()) From 999dd02d0b25ec144c6e12a25ee5a04160473ad7 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Thu, 19 Nov 2020 08:00:01 +0000 Subject: [PATCH 02/10] small tweaks --- python/mxnet/gluon/block.py | 4 ++-- tests/python/unittest/test_gluon_save.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 5241cb90ca91..4eda2ddabd72 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -677,7 +677,7 @@ def save(self, prefix): If the model is not able to be recreated deterministically do not use this set of APIs to save/load your model. - For HybridBlocks, the cached_graph (Symbol & inputs) if + For HybridBlocks, the cached_graph (Symbol & inputs) is saved if it has already been hybridized. Parameters @@ -1402,7 +1402,7 @@ def infer_shape(self, *args): def infer_type(self, *args): """Infers data type of Parameters from inputs.""" self._infer_attrs('infer_type', 'dtype', *args) - + def export(self, path, epoch=0, remove_amp_cast=True): """Export HybridBlock to json format that can be loaded by `gluon.SymbolBlock.imports`, `mxnet.mod.Module` or the C++ interface. diff --git a/tests/python/unittest/test_gluon_save.py b/tests/python/unittest/test_gluon_save.py index dac70968f851..19d6b7e9c90c 100644 --- a/tests/python/unittest/test_gluon_save.py +++ b/tests/python/unittest/test_gluon_save.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import mxnet as mx from common import with_seed @with_seed() @@ -57,3 +58,7 @@ def createNet(): # run inference again out2 = net2(x) mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy()) + +if __name__ == '__main__': + import nose + nose.runmodule() From ecaffd98025651002194b1aa5625dc336d4926cb Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Fri, 20 Nov 2020 08:11:27 +0000 Subject: [PATCH 03/10] renamed load_json to fromjson --- python/mxnet/contrib/quantization.py | 3 +- python/mxnet/gluon/block.py | 9 +- python/mxnet/symbol/symbol.py | 116 +++++++++--------- .../unittest/test_contrib_control_flow.py | 8 +- tests/python/unittest/test_gluon.py | 2 +- tests/python/unittest/test_operator.py | 8 +- 6 files changed, 72 insertions(+), 74 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 1422e07feebb..589a8048fa24 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -32,7 +32,6 @@ from ..base import c_array, c_str, mx_uint, c_str_array from ..base import NDArrayHandle, SymbolHandle from ..symbol import Symbol -from ..symbol import load as sym_load from .. import ndarray from ..ndarray import load as nd_load from ..ndarray import save as nd_save @@ -376,7 +375,7 @@ def _load_sym(sym, logger=None): symbol_file_path = os.path.join(cur_path, sym) if logger: logger.info('Loading symbol from file %s' % symbol_file_path) - return sym_load(symbol_file_path) + return Symbol.load(symbol_file_path) elif isinstance(sym, Symbol): return sym else: diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 4eda2ddabd72..892d0031fe69 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -24,6 +24,7 @@ import copy import warnings import re +import json from collections import OrderedDict, defaultdict import numpy as np @@ -694,7 +695,7 @@ def _save_cached_graphs(blk, index, structure): # encode unique name based on block type and ID name = type(blk).__name__.lower() structure[name+str(index[0])] = mdl - if isinstance(blk, mx.gluon.nn.HybridBlock): + if isinstance(blk, HybridBlock): if blk._cached_graph: # save in/out formats mdl['in_format'] = blk._in_format @@ -763,17 +764,17 @@ def _load_cached_graphs(blk, index, structure): mdl = structure[name+str(index[0])] # rename block to what it was when saved blk._name = mdl['orig_name'] - if isinstance(blk, mx.gluon.nn.HybridBlock): + if isinstance(blk, HybridBlock): if mdl['hybridized']: # restore in/out formats blk._in_format = mdl['in_format'] blk._out_format = mdl['out_format'] # get saved symbol - out = mx.sym.load_json(mdl['symbol']) + out = Symbol.fromjson(mdl['symbol']) syms = [] # recreate inputs for this symbol for inp in mdl['inputs']: - syms.append(mx.sym.load_json(inp)) + syms.append(Symbol.fromjson(inp)) # reset cached_graph and active status blk._cached_graph = (syms, out) blk._active = True diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 0f8cccd071cf..0ed8dd183e9e 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -45,7 +45,7 @@ from ._internal import SymbolBase, _set_symbol_class from ..util import is_np_shape -__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json", +__all__ = ["Symbol", "var", "Variable", "Group", "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange", "linspace", "histogram", "split_v2"] @@ -1364,12 +1364,44 @@ def save(self, fname, remove_amp_cast=True): else: check_call(_LIB.MXSymbolSaveToFile(self.handle, c_str(fname))) + def load(fname): + """Loads symbol from a JSON file. + + You can also use pickle to do the job if you only work on python. + The advantage of load/save is the file is language agnostic. + This means the file saved using save can be loaded by other language binding of mxnet. + You also get the benefit being able to directly load/save from cloud storage(S3, HDFS). + + Parameters + ---------- + fname : str + The name of the file, examples: + + - `s3://my-bucket/path/my-s3-symbol` + - `hdfs://my-bucket/path/my-hdfs-symbol` + - `/path-to/my-local-symbol` + + Returns + ------- + sym : Symbol + The loaded symbol. + + See Also + -------- + Symbol.save : Used to save symbol into file. + """ + if not isinstance(fname, string_types): + raise TypeError('fname need to be string') + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateFromFile(c_str(fname), ctypes.byref(handle))) + return Symbol(handle) + def tojson(self, remove_amp_cast=True): """Saves symbol to a JSON string. See Also -------- - symbol.load_json : Used to load symbol from JSON string. + symbol.fromjson : Used to load symbol from JSON string. """ json_str = ctypes.c_char_p() if remove_amp_cast: @@ -1380,6 +1412,29 @@ def tojson(self, remove_amp_cast=True): check_call(_LIB.MXSymbolSaveToJSON(self.handle, ctypes.byref(json_str))) return py_str(json_str.value) + def fromjson(json_str): + """Loads symbol from json string. + + Parameters + ---------- + json_str : str + A JSON string. + + Returns + ------- + sym : Symbol + The loaded symbol. + + See Also + -------- + Symbol.tojson : Used to save symbol into json string. + """ + if not isinstance(json_str, string_types): + raise TypeError('fname required to be string') + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateFromJSON(c_str(json_str), ctypes.byref(handle))) + return Symbol(handle) + @staticmethod def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing): """Helper function to get NDArray lists handles from various inputs. @@ -3022,63 +3077,6 @@ def Group(symbols, create_fn=Symbol): return create_fn(handle) -def load(fname): - """Loads symbol from a JSON file. - - You can also use pickle to do the job if you only work on python. - The advantage of load/save is the file is language agnostic. - This means the file saved using save can be loaded by other language binding of mxnet. - You also get the benefit being able to directly load/save from cloud storage(S3, HDFS). - - Parameters - ---------- - fname : str - The name of the file, examples: - - - `s3://my-bucket/path/my-s3-symbol` - - `hdfs://my-bucket/path/my-hdfs-symbol` - - `/path-to/my-local-symbol` - - Returns - ------- - sym : Symbol - The loaded symbol. - - See Also - -------- - Symbol.save : Used to save symbol into file. - """ - if not isinstance(fname, string_types): - raise TypeError('fname need to be string') - handle = SymbolHandle() - check_call(_LIB.MXSymbolCreateFromFile(c_str(fname), ctypes.byref(handle))) - return Symbol(handle) - - -def load_json(json_str): - """Loads symbol from json string. - - Parameters - ---------- - json_str : str - A JSON string. - - Returns - ------- - sym : Symbol - The loaded symbol. - - See Also - -------- - Symbol.tojson : Used to save symbol into json string. - """ - if not isinstance(json_str, string_types): - raise TypeError('fname required to be string') - handle = SymbolHandle() - check_call(_LIB.MXSymbolCreateFromJSON(c_str(json_str), ctypes.byref(handle))) - return Symbol(handle) - - # pylint: disable=no-member # pylint: disable=redefined-builtin def pow(base, exp): diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index a93c109d11df..cb223a6101d1 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -1234,7 +1234,7 @@ def verify_foreach(step, in_syms, state_syms, free_syms, out.extend(states) out = mx.sym.Group(out) js_1 = out.tojson() - out = mx.sym.load_json(js_1) + out = mx.sym.fromjson(js_1) js_2 = out.tojson() assert js_1 == js_2 arr_grads = [] @@ -1556,7 +1556,7 @@ def step_nd(in1, states): out = mx.sym.broadcast_add(out, states[0]) js_1 = out.tojson() - out = mx.sym.load_json(js_1) + out = mx.sym.fromjson(js_1) js_2 = out.tojson() assert js_1 == js_2 @@ -1631,7 +1631,7 @@ def sym_group(out): out = mx.sym.contrib.foreach(step, data, init_states) out = sym_group(out) js_1 = out.tojson() - out = mx.sym.load_json(js_1) + out = mx.sym.fromjson(js_1) js_2 = out.tojson() assert js_1 == js_2 e1 = out.bind(ctx=default_context(), args=args1, args_grad=args_grad1) @@ -1647,7 +1647,7 @@ def sym_group(out): unroll_outs.extend(states) out = mx.sym.Group(unroll_outs) js_1 = out.tojson() - out = mx.sym.load_json(js_1) + out = mx.sym.fromjson(js_1) js_2 = out.tojson() assert js_1 == js_2 e2 = out.bind(ctx=default_context(), args=args2, args_grad=args_grad2) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 0eb83401073a..abbc4cf8791a 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -1917,7 +1917,7 @@ def test_legacy_save_params(): a = net(mx.sym.var('data')) a.save('test.json') net.save_params('test.params') - model = gluon.nn.SymbolBlock(outputs=mx.sym.load_json(open('test.json', 'r').read()), + model = gluon.nn.SymbolBlock(outputs=mx.sym.fromjson(open('test.json', 'r').read()), inputs=mx.sym.var('data')) model.load_params('test.params', ctx=mx.cpu()) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 92fd030ba278..994f1deb59f2 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2659,7 +2659,7 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape): net = mx.sym.Variable("data") net = mx.sym.Reshape(net, shape=shape_args, reverse=reverse) js = net.tojson() - net = mx.sym.load_json(js) + net = mx.sym.fromjson(js) _, output_shape, __ = net.infer_shape(data=src_shape) assert output_shape[0] == dst_shape, \ 'Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \ @@ -2728,7 +2728,7 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape): net = mx.sym.Variable("data") net = mx.sym.Reshape(net, target_shape=(2, 0)) js = net.tojson() - net = mx.sym.load_json(js) + net = mx.sym.fromjson(js) _, output_shape, __ = net.infer_shape(data=(2, 3, 5, 5)) assert(output_shape[0] == (2, 75)) # Test for Flatten @@ -2750,7 +2750,7 @@ def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shap rhs = mx.sym.Variable("rhs") net = mx.sym.reshape_like(lhs, rhs, lhs_begin=lbeg, lhs_end=lend, rhs_begin=rbeg, rhs_end=rend) js = net.tojson() - net = mx.sym.load_json(js) + net = mx.sym.fromjson(js) _, output_shape, __ = net.infer_shape(lhs=lhs_shape, rhs=rhs_shape) assert output_shape[0] == dst_shape, \ @@ -2791,7 +2791,7 @@ def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shap rhs = mx.sym.Variable("rhs") net = mx.sym.reshape_like(lhs, rhs) js = net.tojson() - net = mx.sym.load_json(js) + net = mx.sym.fromjson(js) _, output_shape, __ = net.infer_shape(lhs=(40, 30), rhs=(30,20,2)) assert(output_shape[0] == (30,20,2)) From d9005026100b77173446c522e7205ce3522a9484 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Fri, 20 Nov 2020 08:53:44 +0000 Subject: [PATCH 04/10] fixed fromjson --- python/mxnet/gluon/block.py | 6 +- python/mxnet/symbol/symbol.py | 114 +++++++++++++++++----------------- 2 files changed, 61 insertions(+), 59 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 892d0031fe69..0fe5fe313319 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -30,7 +30,7 @@ from ..base import mx_real_t, MXNetError from .. import symbol, ndarray, initializer, np_symbol -from ..symbol import Symbol +from ..symbol import Symbol, fromjson from ..ndarray import NDArray from .. import name as _name from .parameter import Parameter, ParameterDict, DeferredInitializationError @@ -770,11 +770,11 @@ def _load_cached_graphs(blk, index, structure): blk._in_format = mdl['in_format'] blk._out_format = mdl['out_format'] # get saved symbol - out = Symbol.fromjson(mdl['symbol']) + out = fromjson(mdl['symbol']) syms = [] # recreate inputs for this symbol for inp in mdl['inputs']: - syms.append(Symbol.fromjson(inp)) + syms.append(fromjson(inp)) # reset cached_graph and active status blk._cached_graph = (syms, out) blk._active = True diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 0ed8dd183e9e..54e18547ab0d 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -45,7 +45,7 @@ from ._internal import SymbolBase, _set_symbol_class from ..util import is_np_shape -__all__ = ["Symbol", "var", "Variable", "Group", +__all__ = ["Symbol", "var", "Variable", "Group", "load", "fromjson", "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange", "linspace", "histogram", "split_v2"] @@ -1364,38 +1364,6 @@ def save(self, fname, remove_amp_cast=True): else: check_call(_LIB.MXSymbolSaveToFile(self.handle, c_str(fname))) - def load(fname): - """Loads symbol from a JSON file. - - You can also use pickle to do the job if you only work on python. - The advantage of load/save is the file is language agnostic. - This means the file saved using save can be loaded by other language binding of mxnet. - You also get the benefit being able to directly load/save from cloud storage(S3, HDFS). - - Parameters - ---------- - fname : str - The name of the file, examples: - - - `s3://my-bucket/path/my-s3-symbol` - - `hdfs://my-bucket/path/my-hdfs-symbol` - - `/path-to/my-local-symbol` - - Returns - ------- - sym : Symbol - The loaded symbol. - - See Also - -------- - Symbol.save : Used to save symbol into file. - """ - if not isinstance(fname, string_types): - raise TypeError('fname need to be string') - handle = SymbolHandle() - check_call(_LIB.MXSymbolCreateFromFile(c_str(fname), ctypes.byref(handle))) - return Symbol(handle) - def tojson(self, remove_amp_cast=True): """Saves symbol to a JSON string. @@ -1412,29 +1380,6 @@ def tojson(self, remove_amp_cast=True): check_call(_LIB.MXSymbolSaveToJSON(self.handle, ctypes.byref(json_str))) return py_str(json_str.value) - def fromjson(json_str): - """Loads symbol from json string. - - Parameters - ---------- - json_str : str - A JSON string. - - Returns - ------- - sym : Symbol - The loaded symbol. - - See Also - -------- - Symbol.tojson : Used to save symbol into json string. - """ - if not isinstance(json_str, string_types): - raise TypeError('fname required to be string') - handle = SymbolHandle() - check_call(_LIB.MXSymbolCreateFromJSON(c_str(json_str), ctypes.byref(handle))) - return Symbol(handle) - @staticmethod def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing): """Helper function to get NDArray lists handles from various inputs. @@ -3077,6 +3022,63 @@ def Group(symbols, create_fn=Symbol): return create_fn(handle) +def load(fname): + """Loads symbol from a JSON file. + + You can also use pickle to do the job if you only work on python. + The advantage of load/save is the file is language agnostic. + This means the file saved using save can be loaded by other language binding of mxnet. + You also get the benefit being able to directly load/save from cloud storage(S3, HDFS). + + Parameters + ---------- + fname : str + The name of the file, examples: + + - `s3://my-bucket/path/my-s3-symbol` + - `hdfs://my-bucket/path/my-hdfs-symbol` + - `/path-to/my-local-symbol` + + Returns + ------- + sym : Symbol + The loaded symbol. + + See Also + -------- + Symbol.save : Used to save symbol into file. + """ + if not isinstance(fname, string_types): + raise TypeError('fname need to be string') + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateFromFile(c_str(fname), ctypes.byref(handle))) + return Symbol(handle) + + +def fromjson(json_str): + """Loads symbol from json string. + + Parameters + ---------- + json_str : str + A JSON string. + + Returns + ------- + sym : Symbol + The loaded symbol. + + See Also + -------- + Symbol.tojson : Used to save symbol into json string. + """ + if not isinstance(json_str, string_types): + raise TypeError('fname required to be string') + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateFromJSON(c_str(json_str), ctypes.byref(handle))) + return Symbol(handle) + + # pylint: disable=no-member # pylint: disable=redefined-builtin def pow(base, exp): From 281bfec1c1b230e79e76cfb275cf7c26c13c1a6c Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Fri, 20 Nov 2020 09:45:29 +0000 Subject: [PATCH 05/10] fixed sanity --- python/mxnet/gluon/block.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 0fe5fe313319..aa515379d9fe 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -722,7 +722,7 @@ def _save_cached_graphs(blk, index, structure): index = [0] _save_cached_graphs(self, index, model) # save model - fp = open(prefix+'-model.json','w') + fp = open(prefix+'-model.json', 'w') json.dump(model, fp) fp.close() # save params @@ -734,7 +734,7 @@ def load(self, prefix): Reconfigures a model using the saved configuration. This function does not regenerate the model architecture. It resets the children's names as they were when saved in order to match the names of the - saved parameters. + saved parameters. This function assumes the Blocks in the model were created in the same order they were when the model was saved. This is because each Block is @@ -784,7 +784,7 @@ def _load_cached_graphs(blk, index, structure): param = blk.params._params[p] new_name = blk.name +'_'+ p[len(blk.params._prefix):] blk.params._params.pop(p) - blk.params._params[new_name] = param + blk.params._params[new_name] = param # recursively reload children for ch_name, child in blk._children.items(): index[0] += 1 From 774fbcbc00933352808033445f0a463798b15478 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Fri, 20 Nov 2020 17:46:52 +0000 Subject: [PATCH 06/10] changed tests data to zeros --- tests/python/unittest/test_gluon_save.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_gluon_save.py b/tests/python/unittest/test_gluon_save.py index 19d6b7e9c90c..95ae7d91b256 100644 --- a/tests/python/unittest/test_gluon_save.py +++ b/tests/python/unittest/test_gluon_save.py @@ -45,7 +45,7 @@ def createNet(): net1.initialize() # hybridize (the hybridizeable blocks, ie. the Dense layers) net1.hybridize() - x = mx.nd.empty((1,10)) + x = mx.nd.zeros((1,10)) out1 = net1(x) # save hybridized model From 43b390f7ce541c647f56515a349f8bac93fa4501 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Sat, 12 Dec 2020 02:02:24 +0000 Subject: [PATCH 07/10] undo renaming in quantization --- python/mxnet/contrib/quantization.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 589a8048fa24..1422e07feebb 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -32,6 +32,7 @@ from ..base import c_array, c_str, mx_uint, c_str_array from ..base import NDArrayHandle, SymbolHandle from ..symbol import Symbol +from ..symbol import load as sym_load from .. import ndarray from ..ndarray import load as nd_load from ..ndarray import save as nd_save @@ -375,7 +376,7 @@ def _load_sym(sym, logger=None): symbol_file_path = os.path.join(cur_path, sym) if logger: logger.info('Loading symbol from file %s' % symbol_file_path) - return Symbol.load(symbol_file_path) + return sym_load(symbol_file_path) elif isinstance(sym, Symbol): return sym else: From 162dd884f3ade4176b12d13a1784a9d5e20b1f59 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Sat, 12 Dec 2020 02:04:30 +0000 Subject: [PATCH 08/10] fix indent --- python/mxnet/symbol/symbol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 54e18547ab0d..1ab34cca7265 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -3035,9 +3035,9 @@ def load(fname): fname : str The name of the file, examples: - - `s3://my-bucket/path/my-s3-symbol` - - `hdfs://my-bucket/path/my-hdfs-symbol` - - `/path-to/my-local-symbol` + - `s3://my-bucket/path/my-s3-symbol` + - `hdfs://my-bucket/path/my-hdfs-symbol` + - `/path-to/my-local-symbol` Returns ------- From 0e42b192b7585be08afe549cc790ec0552f44dec Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Sat, 12 Dec 2020 02:14:43 +0000 Subject: [PATCH 09/10] changed to with open --- python/mxnet/gluon/block.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index aa515379d9fe..d89bc89bd2c9 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -722,9 +722,8 @@ def _save_cached_graphs(blk, index, structure): index = [0] _save_cached_graphs(self, index, model) # save model - fp = open(prefix+'-model.json', 'w') - json.dump(model, fp) - fp.close() + with open(prefix+'-model.json', 'w') as fp: + json.dump(model, fp) # save params self.save_parameters(prefix+'-model.params') @@ -754,9 +753,9 @@ def load(self, prefix): -model.json and -model.params """ # load model json from file - fp = open(prefix+'-model.json') - model = json.load(fp) - fp.close() + with open(prefix+'-model.json') as fp: + model = json.load(fp) + def _load_cached_graphs(blk, index, structure): # get block name name = type(blk).__name__.lower() From 6d0a52401f8667020384ffa2d4cf10c6ff092403 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Mon, 14 Dec 2020 22:12:57 +0000 Subject: [PATCH 10/10] undo load_json -> fromjson change --- python/mxnet/gluon/block.py | 6 +++--- python/mxnet/symbol/symbol.py | 6 +++--- tests/python/unittest/test_contrib_control_flow.py | 8 ++++---- tests/python/unittest/test_gluon.py | 2 +- tests/python/unittest/test_operator.py | 8 ++++---- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index d89bc89bd2c9..d415c5fea511 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -30,7 +30,7 @@ from ..base import mx_real_t, MXNetError from .. import symbol, ndarray, initializer, np_symbol -from ..symbol import Symbol, fromjson +from ..symbol import Symbol, load_json from ..ndarray import NDArray from .. import name as _name from .parameter import Parameter, ParameterDict, DeferredInitializationError @@ -769,11 +769,11 @@ def _load_cached_graphs(blk, index, structure): blk._in_format = mdl['in_format'] blk._out_format = mdl['out_format'] # get saved symbol - out = fromjson(mdl['symbol']) + out = load_json(mdl['symbol']) syms = [] # recreate inputs for this symbol for inp in mdl['inputs']: - syms.append(fromjson(inp)) + syms.append(load_json(inp)) # reset cached_graph and active status blk._cached_graph = (syms, out) blk._active = True diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 1ab34cca7265..0f8cccd071cf 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -45,7 +45,7 @@ from ._internal import SymbolBase, _set_symbol_class from ..util import is_np_shape -__all__ = ["Symbol", "var", "Variable", "Group", "load", "fromjson", +__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json", "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange", "linspace", "histogram", "split_v2"] @@ -1369,7 +1369,7 @@ def tojson(self, remove_amp_cast=True): See Also -------- - symbol.fromjson : Used to load symbol from JSON string. + symbol.load_json : Used to load symbol from JSON string. """ json_str = ctypes.c_char_p() if remove_amp_cast: @@ -3055,7 +3055,7 @@ def load(fname): return Symbol(handle) -def fromjson(json_str): +def load_json(json_str): """Loads symbol from json string. Parameters diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index cb223a6101d1..a93c109d11df 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -1234,7 +1234,7 @@ def verify_foreach(step, in_syms, state_syms, free_syms, out.extend(states) out = mx.sym.Group(out) js_1 = out.tojson() - out = mx.sym.fromjson(js_1) + out = mx.sym.load_json(js_1) js_2 = out.tojson() assert js_1 == js_2 arr_grads = [] @@ -1556,7 +1556,7 @@ def step_nd(in1, states): out = mx.sym.broadcast_add(out, states[0]) js_1 = out.tojson() - out = mx.sym.fromjson(js_1) + out = mx.sym.load_json(js_1) js_2 = out.tojson() assert js_1 == js_2 @@ -1631,7 +1631,7 @@ def sym_group(out): out = mx.sym.contrib.foreach(step, data, init_states) out = sym_group(out) js_1 = out.tojson() - out = mx.sym.fromjson(js_1) + out = mx.sym.load_json(js_1) js_2 = out.tojson() assert js_1 == js_2 e1 = out.bind(ctx=default_context(), args=args1, args_grad=args_grad1) @@ -1647,7 +1647,7 @@ def sym_group(out): unroll_outs.extend(states) out = mx.sym.Group(unroll_outs) js_1 = out.tojson() - out = mx.sym.fromjson(js_1) + out = mx.sym.load_json(js_1) js_2 = out.tojson() assert js_1 == js_2 e2 = out.bind(ctx=default_context(), args=args2, args_grad=args_grad2) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index abbc4cf8791a..0eb83401073a 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -1917,7 +1917,7 @@ def test_legacy_save_params(): a = net(mx.sym.var('data')) a.save('test.json') net.save_params('test.params') - model = gluon.nn.SymbolBlock(outputs=mx.sym.fromjson(open('test.json', 'r').read()), + model = gluon.nn.SymbolBlock(outputs=mx.sym.load_json(open('test.json', 'r').read()), inputs=mx.sym.var('data')) model.load_params('test.params', ctx=mx.cpu()) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 994f1deb59f2..92fd030ba278 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2659,7 +2659,7 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape): net = mx.sym.Variable("data") net = mx.sym.Reshape(net, shape=shape_args, reverse=reverse) js = net.tojson() - net = mx.sym.fromjson(js) + net = mx.sym.load_json(js) _, output_shape, __ = net.infer_shape(data=src_shape) assert output_shape[0] == dst_shape, \ 'Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \ @@ -2728,7 +2728,7 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape): net = mx.sym.Variable("data") net = mx.sym.Reshape(net, target_shape=(2, 0)) js = net.tojson() - net = mx.sym.fromjson(js) + net = mx.sym.load_json(js) _, output_shape, __ = net.infer_shape(data=(2, 3, 5, 5)) assert(output_shape[0] == (2, 75)) # Test for Flatten @@ -2750,7 +2750,7 @@ def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shap rhs = mx.sym.Variable("rhs") net = mx.sym.reshape_like(lhs, rhs, lhs_begin=lbeg, lhs_end=lend, rhs_begin=rbeg, rhs_end=rend) js = net.tojson() - net = mx.sym.fromjson(js) + net = mx.sym.load_json(js) _, output_shape, __ = net.infer_shape(lhs=lhs_shape, rhs=rhs_shape) assert output_shape[0] == dst_shape, \ @@ -2791,7 +2791,7 @@ def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shap rhs = mx.sym.Variable("rhs") net = mx.sym.reshape_like(lhs, rhs) js = net.tojson() - net = mx.sym.fromjson(js) + net = mx.sym.load_json(js) _, output_shape, __ = net.infer_shape(lhs=(40, 30), rhs=(30,20,2)) assert(output_shape[0] == (30,20,2))