From 765297bf5d00054fdc40f8987e8238fcfa5eca60 Mon Sep 17 00:00:00 2001
From: Sam Skalicky
Date: Mon, 14 Dec 2020 11:06:39 -0800
Subject: [PATCH] Save/Load Gluon Blocks & HybridBlocks (#19564)

---
 python/mxnet/gluon/block.py                   | 140 +++++++++++++++++-
 python/mxnet/symbol/symbol.py                 |   6 +-
 .../unittest/test_contrib_control_flow.py     |   4 +-
 .../python/unittest/test_contrib_operator.py  |   2 +-
 tests/python/unittest/test_gluon_save.py      |  60 ++++++++
 tests/python/unittest/test_operator.py        |   8 +-
 6 files changed, 209 insertions(+), 11 deletions(-)
 create mode 100644 tests/python/unittest/test_gluon_save.py

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 28e6f3622ba7..06025b116350 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -29,13 +29,14 @@
 import contextvars
 import re
+import json

 import numpy as np

 from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
 from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
     profiler as _profiler, context as _context
 from ..symbol.numpy import _symbol as np_symbol
-from ..symbol import Symbol
+from ..symbol import Symbol, fromjson
 from ..ndarray import NDArray
 from .parameter import Parameter, DeferredInitializationError
 from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
@@ -573,6 +574,143 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
                 for v in params.values():
                     v.initialize(None, ctx, init, force_reinit=force_reinit)

+    def save(self, prefix):
+        """Save the model architecture and parameters so they can be loaded later
+
+        Saves the model architecture as a nested dictionary where each Block
+        in the model is a dictionary and its children are sub-dictionaries.
+
+        Each Block is uniquely identified by its class name and a unique ID.
+        We save each Block's parameter UUIDs to restore later in order to match
+        the saved parameters.
+
+        Recursively traverses a Block's children in order (they are stored in an
+        OrderedDict) and uses the unique ID to denote that specific Block.
+
+        Assumes that the model is created in an identical order every time.
+        If the model cannot be recreated deterministically, do not use this
+        set of APIs to save/load your model.
+
+        For HybridBlocks, the cached_graph (Symbol & inputs) is saved if the
+        block has already been hybridized.
+
+        Parameters
+        ----------
+        prefix : str
+            The prefix to use in filenames for saving this model:
+            <prefix>-model.json and <prefix>-model.params
+        """
+        # create empty model structure
+        model = {}
+        def _save_cached_graphs(blk, structure, index=0):
+            # create new entry for this block
+            mdl = {}
+            # encode unique name based on block type and ID
+            name = type(blk).__name__.lower()
+            structure[name+str(index)] = mdl
+            index += 1
+            if isinstance(blk, HybridBlock):
+                if blk._cached_graph:
+                    # save in/out formats
+                    mdl['in_format'] = blk._in_format
+                    mdl['out_format'] = blk._out_format
+                    # save cached graph & input symbols
+                    syms, out = blk._cached_graph
+                    mdl_syms = []
+                    for sym in syms:
+                        mdl_syms.append(sym.tojson())
+                    mdl['inputs'] = mdl_syms
+                    mdl['symbol'] = out.tojson()
+                    mdl['hybridized'] = True
+                else:
+                    mdl['hybridized'] = False
+            # save param uuids
+            pmap = {}
+            mdl['params'] = pmap
+            pnames = list(blk.params.keys())
+            for p in pnames:
+                param = blk.params[p]
+                pmap[p] = param._uuid
+            # recursively save children
+            for child in blk._children.values():
+                index = _save_cached_graphs(child(), mdl, index)
+            # return latest index (ie. block count)
+            return index
+
+        # save top-level block
+        _save_cached_graphs(self, model)
+        # save model
+        with open(prefix+'-model.json', 'w') as fp:
+            json.dump(model, fp)
+        # save params
+        self.save_parameters(prefix+'-model.params')
+
+    def load(self, prefix):
+        """Load a model saved using the `save` API
+
+        Reconfigures a model using the saved configuration. This function
+        does not regenerate the model architecture. It restores each Block's
+        parameter UUIDs to their saved values so that they match the names of
+        the saved parameters.
+
+        This function assumes the Blocks in the model were created in the same
+        order they were when the model was saved. Each Block is identified by its
+        class name and a unique ID assigned in traversal order (children are stored
+        in an OrderedDict), and that ID is used to look up that Block's saved entry.
+
+        Assumes that the model is created in an identical order every time.
+        If the model cannot be recreated deterministically, do not use this
+        set of APIs to save/load your model.
+
+        For HybridBlocks, the cached_graph (Symbol & inputs) and settings are
+        restored if it had been hybridized before saving.
+
+        Parameters
+        ----------
+        prefix : str
+            The prefix to use in filenames for loading this model:
+            <prefix>-model.json and <prefix>-model.params
+        """
+        # load model json from file
+        with open(prefix+'-model.json') as fp:
+            model = json.load(fp)
+
+        def _load_cached_graphs(blk, structure, index=0):
+            # get block name
+            name = type(blk).__name__.lower()
+            # lookup previous encoded name based on block type and ID
+            mdl = structure[name+str(index)]
+            index += 1
+            if isinstance(blk, HybridBlock):
+                if mdl['hybridized']:
+                    # restore in/out formats
+                    blk._in_format = mdl['in_format']
+                    blk._out_format = mdl['out_format']
+                    # get saved symbol
+                    out = fromjson(mdl['symbol'])
+                    syms = []
+                    # recreate inputs for this symbol
+                    for inp in mdl['inputs']:
+                        syms.append(fromjson(inp))
+                    # reset cached_graph and active status
+                    blk._cached_graph = (syms, out)
+                    blk._active = True
+            # reload param uuids
+            pmap = mdl['params']
+            for p, uuid in pmap.items():
+                param = blk.params[p]
+                param._uuid = uuid
+            # recursively reload children
+            for child in blk._children.values():
+                index = _load_cached_graphs(child(), mdl, index)
+            # return latest index (ie. block count)
+            return index
+
+        # load top-level block
+        _load_cached_graphs(self, model)
+        # load params
+        self.load_parameters(prefix+'-model.params')
+
     def hybridize(self, active=True, **kwargs):
         """ Please refer description of HybridBlock hybridize().
         """
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 9082a62a4e99..3ef6281faea5 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -46,7 +46,7 @@
 from ..profiler import scope as _profiler_scope
 from ..profiler import _current_scope as _current_profiler_scope

-__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
+__all__ = ["Symbol", "var", "Variable", "Group", "load", "fromjson",
            "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros",
            "ones", "full", "arange", "linspace", "histogram", "split_v2"]

@@ -1400,7 +1400,7 @@ def tojson(self, remove_amp_cast=True):

         See Also
         --------
-        symbol.load_json : Used to load symbol from JSON string.
+        symbol.fromjson : Used to load symbol from JSON string.
         """
         json_str = ctypes.c_char_p()
         if remove_amp_cast:
@@ -2821,7 +2821,7 @@ def load(fname):
     return Symbol(handle)


-def load_json(json_str):
+def fromjson(json_str):
     """Loads symbol from json string.

     Parameters
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index 76fad4f5e4ac..e538dd4b1633 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -1135,7 +1135,7 @@ def verify_foreach(step, in_syms, state_syms, free_syms,
     out.extend(states)
     out = mx.sym.Group(out)
     js_1 = out.tojson()
-    out = mx.sym.load_json(js_1)
+    out = mx.sym.fromjson(js_1)
     js_2 = out.tojson()
     assert js_1 == js_2
     arr_grads = []
@@ -1457,7 +1457,7 @@ def step_nd(in1, states):
         out = mx.sym.broadcast_add(out, states[0])

     js_1 = out.tojson()
-    out = mx.sym.load_json(js_1)
+    out = mx.sym.fromjson(js_1)
     js_2 = out.tojson()
     assert js_1 == js_2

diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py
index c0c14b7add8f..5f132f96c2de 100644
--- a/tests/python/unittest/test_contrib_operator.py
+++ b/tests/python/unittest/test_contrib_operator.py
@@ -415,7 +415,7 @@ def dynamic_reshape_testcases(src_shape, shape_arg, dst_shape):
         shape = mx.sym.Variable('shape')
         net = mx.sym.contrib.dynamic_reshape(data, shape)
         js = net.tojson()
-        net = mx.sym.load_json(js)
+        net = mx.sym.fromjson(js)
         dat_npy = np.random.rand(*src_shape)
         grad_npy = np.random.rand(*dst_shape)
         args = {
diff --git a/tests/python/unittest/test_gluon_save.py b/tests/python/unittest/test_gluon_save.py
new file mode 100644
index 000000000000..c17df63dc64f
--- /dev/null
+++ b/tests/python/unittest/test_gluon_save.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+
+def test_save():
+    class MyBlock(mx.gluon.Block):
+        def __init__(self, **kwargs):
+            super(MyBlock, self).__init__(**kwargs)
+            self.layers = []
+        def add(self, block):
+            self.layers.append(block)
+            self.register_child(block)
+        def forward(self, x, *args):
+            out = (x,) + args
+            for block in self._children.values():
+                out = block()(*out)
+            return out
+
+    def createNet():
+        inside = MyBlock()
+        dense = mx.gluon.nn.Dense(10)
+        inside.add(dense)
+        net = MyBlock()
+        net.add(inside)
+        net.add(mx.gluon.nn.Dense(10))
+        return net
+
+    # create and initialize model
+    net1 = createNet()
+    net1.initialize()
+    # hybridize (the hybridizeable blocks, ie. the Dense layers)
+    net1.hybridize()
+    x = mx.nd.zeros((1,10))
+    out1 = net1(x)
+
+    # save hybridized model
+    net1.save('MyModel')
+
+    # create a new model, uninitialized
+    net2 = createNet()
+    # reload hybridized model
+    net2.load('MyModel')
+    # run inference again
+    out2 = net2(x)
+    mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy())
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 31a6aa13864c..5034b07d9ddc 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2269,7 +2269,7 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape):
         net = mx.sym.Variable("data")
         net = mx.sym.Reshape(net, shape=shape_args, reverse=reverse)
         js = net.tojson()
-        net = mx.sym.load_json(js)
+        net = mx.sym.fromjson(js)
         _, output_shape, __ = net.infer_shape(data=src_shape)
         assert output_shape[0] == dst_shape, \
             'Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
@@ -2308,7 +2308,7 @@ def test_reshape_old():
     net = mx.sym.Variable("data")
     net = mx.sym.Reshape(net, target_shape=(2, 0))
     js = net.tojson()
-    net = mx.sym.load_json(js)
+    net = mx.sym.fromjson(js)
     _, output_shape, __ = net.infer_shape(data=(2, 3, 5, 5))
     assert(output_shape[0] == (2, 75))
     # Test for Flatten
@@ -2329,7 +2329,7 @@ def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shap
        rhs = mx.sym.Variable("rhs")
        net = mx.sym.reshape_like(lhs, rhs, lhs_begin=lbeg, lhs_end=lend, rhs_begin=rbeg, rhs_end=rend)
        js = net.tojson()
-       net = mx.sym.load_json(js)
+       net = mx.sym.fromjson(js)
        _, output_shape, __ = net.infer_shape(lhs=lhs_shape, rhs=rhs_shape)

        assert output_shape[0] == dst_shape, \
@@ -2370,7 +2370,7 @@ def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shap
     rhs = mx.sym.Variable("rhs")
     net = mx.sym.reshape_like(lhs, rhs)
     js = net.tojson()
-    net = mx.sym.load_json(js)
+    net = mx.sym.fromjson(js)
     _, output_shape, __ = net.infer_shape(lhs=(40, 30), rhs=(30,20,2))
     assert(output_shape[0] == (30,20,2))
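
Example usage of the new Block.save / Block.load API: a minimal sketch that mirrors tests/python/unittest/test_gluon_save.py. It assumes a simple HybridSequential model and the arbitrary prefix 'mymodel' (the unit test above uses a custom Block subclass and the prefix 'MyModel'); the model must be rebuilt in the same deterministic order before load() is called.

    import mxnet as mx

    def build():
        # any deterministically-built model; Dense is a HybridBlock, so its
        # cached graph is saved once the model is hybridized and run
        net = mx.gluon.nn.HybridSequential()
        net.add(mx.gluon.nn.Dense(10))
        return net

    net1 = build()
    net1.initialize()
    net1.hybridize()
    x = mx.nd.zeros((1, 10))
    out1 = net1(x)

    net1.save('mymodel')   # writes mymodel-model.json and the matching -model.params file

    net2 = build()         # same architecture, created in the same order
    net2.load('mymodel')   # restores cached graphs, parameter UUIDs and parameters
    out2 = net2(x)
    mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy())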
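The test updates above also reflect the rename of mx.sym.load_json to mx.sym.fromjson. A small round-trip sketch, using arbitrary variable names 'a' and 'b' for illustration:

    import mxnet as mx

    a = mx.sym.Variable('a')
    b = mx.sym.Variable('b')
    c = 2 * a + b
    js = c.tojson()             # serialize the Symbol to a JSON string
    c2 = mx.sym.fromjson(js)    # reload it with the renamed API (formerly load_json)
    assert c2.tojson() == js    # same round-trip check used in the updated tests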