Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Save/Load Gluon Blocks & HybridBlocks #19565

Merged
merged 10 commits into from
Dec 15, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/mxnet/contrib/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from ..base import c_array, c_str, mx_uint, c_str_array
from ..base import NDArrayHandle, SymbolHandle
from ..symbol import Symbol
from ..symbol import load as sym_load
from .. import ndarray
from ..ndarray import load as nd_load
from ..ndarray import save as nd_save
Expand Down Expand Up @@ -376,7 +375,7 @@ def _load_sym(sym, logger=None):
symbol_file_path = os.path.join(cur_path, sym)
if logger:
logger.info('Loading symbol from file %s' % symbol_file_path)
return sym_load(symbol_file_path)
return Symbol.load(symbol_file_path)
samskalicky marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(sym, Symbol):
return sym
else:
Expand Down
146 changes: 145 additions & 1 deletion python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
import copy
import warnings
import re
import json
from collections import OrderedDict, defaultdict
import numpy as np

from ..base import mx_real_t, MXNetError
from .. import symbol, ndarray, initializer, np_symbol
from ..symbol import Symbol
from ..symbol import Symbol, fromjson
from ..ndarray import NDArray
from .. import name as _name
from .parameter import Parameter, ParameterDict, DeferredInitializationError
Expand Down Expand Up @@ -661,6 +662,149 @@ def hybridize(self, active=True, **kwargs):
for cld in self._children.values():
cld.hybridize(active, **kwargs)

def save(self, prefix):
    """Save the model architecture and parameters to load again later

    Saves the model architecture as a nested dictionary where each Block
    in the model is a dictionary and its children are sub-dictionaries.

    Each Block is uniquely identified by Block class name and a unique ID.
    We save the child's name that the parent uses for it to restore later
    in order to match the saved parameters.

    Recursively traverses a Block's children in order (since it's an
    OrderedDict) and uses the unique ID to denote that specific Block.
    Assumes that the model is created in an identical order every time.
    If the model is not able to be recreated deterministically do not
    use this set of APIs to save/load your model.

    For HybridBlocks, the cached_graph (Symbol & inputs) is saved if
    it has already been hybridized.

    The JSON entry written for each Block is keyed by
    ``<lowercased class name><visit index>`` and contains:

    - ``orig_name``: the Block's ``name`` at save time
    - ``children``: mapping of each child's ``name`` to the key the parent
      stores it under in ``_children``
    - ``hybridized`` (HybridBlocks only), plus ``in_format``,
      ``out_format``, ``inputs`` and ``symbol`` when a cached graph exists
    - one nested entry per child, in the same format

    Parameters
    ----------
    prefix : str
        The prefix to use in filenames for saving this model:
        <prefix>-model.json and <prefix>-model.params
    """
    def _save_cached_graphs(blk, index, structure):
        # create new entry for this block
        mdl = {'orig_name': blk.name}
        # encode unique name based on block type and ID
        name = type(blk).__name__.lower()
        structure[name+str(index[0])] = mdl
        if isinstance(blk, HybridBlock):
            if blk._cached_graph:
                # save in/out formats
                mdl['in_format'] = blk._in_format
                mdl['out_format'] = blk._out_format
                # save cached graph & input symbols
                syms, out = blk._cached_graph
                mdl['inputs'] = [sym.tojson() for sym in syms]
                mdl['symbol'] = out.tojson()
                mdl['hybridized'] = True
            else:
                mdl['hybridized'] = False
        children = {}
        mdl['children'] = children
        # recursively save children; `index` is a one-element list so the
        # running counter is shared across all levels of the recursion
        for ch_name, child in blk._children.items():
            index[0] += 1
            # save child's original name in this block's map
            children[child.name] = ch_name
            _save_cached_graphs(child, index, mdl)

    # build the nested model structure starting from the top-level block
    model = {}
    _save_cached_graphs(self, [0], model)
    # save model; the context manager closes the file even if
    # serialization raises part-way through
    with open(prefix+'-model.json', 'w') as fp:
        json.dump(model, fp)
    # save params
    self.save_parameters(prefix+'-model.params')

def load(self, prefix):
    """Load a model saved using the `save` API

    Reconfigures a model using the saved configuration. This function
    does not regenerate the model architecture. It resets the children's
    names as they were when saved in order to match the names of the
    saved parameters.

    This function assumes the Blocks in the model were created in the same
    order they were when the model was saved. This is because each Block is
    uniquely identified by Block class name and a unique ID assigned in
    traversal order (children are kept in an OrderedDict), and that ID is
    used to look up the saved entry for that specific Block.
    If the model is not able to be recreated deterministically do not
    use this set of APIs to save/load your model.

    For HybridBlocks, the cached_graph (Symbol & inputs) and settings are
    restored if it had been hybridized before saving.

    Parameters
    ----------
    prefix : str
        The prefix to use in filenames for loading this model:
        <prefix>-model.json and <prefix>-model.params

    Raises
    ------
    KeyError
        If a Block's ``<lowercased class name><visit index>`` key, or a
        child's name, is not present in the saved structure (i.e. the
        model was not recreated in the same order as when it was saved).
    """
    # load model json from file; the context manager closes the file even
    # if parsing raises
    with open(prefix+'-model.json') as fp:
        model = json.load(fp)

    def _load_cached_graphs(blk, index, structure):
        # get block name
        name = type(blk).__name__.lower()
        # lookup previous encoded name based on block type and ID
        mdl = structure[name+str(index[0])]
        # rename block to what it was when saved
        blk._name = mdl['orig_name']
        if isinstance(blk, HybridBlock):
            if mdl['hybridized']:
                # restore in/out formats
                blk._in_format = mdl['in_format']
                blk._out_format = mdl['out_format']
                # get saved symbol
                out = fromjson(mdl['symbol'])
                # recreate inputs for this symbol
                syms = [fromjson(inp) for inp in mdl['inputs']]
                # reset cached_graph and active status
                blk._cached_graph = (syms, out)
                blk._active = True
        # rename params with updated block name; iterate over a snapshot
        # of the keys because the dict is mutated inside the loop
        for p in list(blk.params.keys()):
            new_name = blk.name +'_'+ p[len(blk.params._prefix):]
            blk.params._params[new_name] = blk.params._params.pop(p)
        # recursively reload children; `index` is a one-element list so the
        # running counter is shared across all levels of the recursion
        for child in blk._children.values():
            index[0] += 1
            _load_cached_graphs(child, index, mdl)
        # original child names, keyed by the (restored) child block name
        children = mdl['children']
        # loop over a snapshot of the current keys and remap each child
        # to the key it was stored under when the model was saved
        for ch_name in list(blk._children.keys()):
            child = blk._children.pop(ch_name)
            blk._children[children[child.name]] = child

    # load top-level block
    _load_cached_graphs(self, [0], model)
    # load params
    self.load_parameters(prefix+'-model.params')

def cast(self, dtype):
"""Cast this Block to use another data type.

Expand Down
12 changes: 6 additions & 6 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from ._internal import SymbolBase, _set_symbol_class
from ..util import is_np_shape

__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
__all__ = ["Symbol", "var", "Variable", "Group", "load", "fromjson",
"pow", "power", "maximum", "minimum", "hypot", "eye", "zeros",
"ones", "full", "arange", "linspace", "histogram", "split_v2"]

Expand Down Expand Up @@ -1369,7 +1369,7 @@ def tojson(self, remove_amp_cast=True):

See Also
--------
symbol.load_json : Used to load symbol from JSON string.
symbol.fromjson : Used to load symbol from JSON string.
"""
json_str = ctypes.c_char_p()
if remove_amp_cast:
Expand Down Expand Up @@ -3035,9 +3035,9 @@ def load(fname):
fname : str
The name of the file, examples:

- `s3://my-bucket/path/my-s3-symbol`
- `hdfs://my-bucket/path/my-hdfs-symbol`
- `/path-to/my-local-symbol`
- `s3://my-bucket/path/my-s3-symbol`
- `hdfs://my-bucket/path/my-hdfs-symbol`
- `/path-to/my-local-symbol`
samskalicky marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
Expand All @@ -3055,7 +3055,7 @@ def load(fname):
return Symbol(handle)


def load_json(json_str):
def fromjson(json_str):
"""Loads symbol from json string.

Parameters
Expand Down
8 changes: 4 additions & 4 deletions tests/python/unittest/test_contrib_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,7 @@ def verify_foreach(step, in_syms, state_syms, free_syms,
out.extend(states)
out = mx.sym.Group(out)
js_1 = out.tojson()
out = mx.sym.load_json(js_1)
out = mx.sym.fromjson(js_1)
js_2 = out.tojson()
assert js_1 == js_2
arr_grads = []
Expand Down Expand Up @@ -1556,7 +1556,7 @@ def step_nd(in1, states):
out = mx.sym.broadcast_add(out, states[0])

js_1 = out.tojson()
out = mx.sym.load_json(js_1)
out = mx.sym.fromjson(js_1)
js_2 = out.tojson()
assert js_1 == js_2

Expand Down Expand Up @@ -1631,7 +1631,7 @@ def sym_group(out):
out = mx.sym.contrib.foreach(step, data, init_states)
out = sym_group(out)
js_1 = out.tojson()
out = mx.sym.load_json(js_1)
out = mx.sym.fromjson(js_1)
js_2 = out.tojson()
assert js_1 == js_2
e1 = out.bind(ctx=default_context(), args=args1, args_grad=args_grad1)
Expand All @@ -1647,7 +1647,7 @@ def sym_group(out):
unroll_outs.extend(states)
out = mx.sym.Group(unroll_outs)
js_1 = out.tojson()
out = mx.sym.load_json(js_1)
out = mx.sym.fromjson(js_1)
js_2 = out.tojson()
assert js_1 == js_2
e2 = out.bind(ctx=default_context(), args=args2, args_grad=args_grad2)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,7 +1917,7 @@ def test_legacy_save_params():
a = net(mx.sym.var('data'))
a.save('test.json')
net.save_params('test.params')
model = gluon.nn.SymbolBlock(outputs=mx.sym.load_json(open('test.json', 'r').read()),
model = gluon.nn.SymbolBlock(outputs=mx.sym.fromjson(open('test.json', 'r').read()),
inputs=mx.sym.var('data'))
model.load_params('test.params', ctx=mx.cpu())

Expand Down
64 changes: 64 additions & 0 deletions tests/python/unittest/test_gluon_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
from common import with_seed

@with_seed()
def test_save():
    """Round-trip a nested, hybridized model through Block.save/Block.load.

    Builds the same architecture twice, saves the first (initialized and
    hybridized) instance, loads it into the second (uninitialized) instance
    and checks that both produce the same output for the same input.
    """
    # local imports keep this test self-contained; the model files are
    # written to a temporary directory so nothing is left in the cwd
    import os
    import tempfile

    class MyBlock(mx.gluon.nn.Block):
        def __init__(self, **kwargs):
            super(MyBlock, self).__init__(**kwargs)
        def add(self, block):
            self._children[block.name + str(len(self._children))] = block
        def forward(self, x, *args):
            out = (x,) + args
            for block in self._children.values():
                out = block(*out)
            return out

    def createNet():
        inside = MyBlock()
        dense = mx.gluon.nn.Dense(10)
        inside.add(dense)
        net = MyBlock()
        net.add(inside)
        net.add(mx.gluon.nn.Dense(10))
        return net

    # create and initialize model
    net1 = createNet()
    net1.initialize()
    # hybridize (the hybridizeable blocks, ie. the Dense layers)
    net1.hybridize()
    x = mx.nd.zeros((1,10))
    out1 = net1(x)

    with tempfile.TemporaryDirectory() as tmpdir:
        prefix = os.path.join(tmpdir, 'MyModel')
        # save hybridized model
        net1.save(prefix)

        # create a new model, uninitialized
        net2 = createNet()
        # reload hybridized model
        net2.load(prefix)
        # run inference again
        out2 = net2(x)
    mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy())

# Allow running this test module directly as a script; nose discovers and
# runs the tests defined in this module.
if __name__ == '__main__':
    import nose
    nose.runmodule()
8 changes: 4 additions & 4 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2659,7 +2659,7 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape):
net = mx.sym.Variable("data")
net = mx.sym.Reshape(net, shape=shape_args, reverse=reverse)
js = net.tojson()
net = mx.sym.load_json(js)
net = mx.sym.fromjson(js)
_, output_shape, __ = net.infer_shape(data=src_shape)
assert output_shape[0] == dst_shape, \
'Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
Expand Down Expand Up @@ -2728,7 +2728,7 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape):
net = mx.sym.Variable("data")
net = mx.sym.Reshape(net, target_shape=(2, 0))
js = net.tojson()
net = mx.sym.load_json(js)
net = mx.sym.fromjson(js)
_, output_shape, __ = net.infer_shape(data=(2, 3, 5, 5))
assert(output_shape[0] == (2, 75))
# Test for Flatten
Expand All @@ -2750,7 +2750,7 @@ def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shap
rhs = mx.sym.Variable("rhs")
net = mx.sym.reshape_like(lhs, rhs, lhs_begin=lbeg, lhs_end=lend, rhs_begin=rbeg, rhs_end=rend)
js = net.tojson()
net = mx.sym.load_json(js)
net = mx.sym.fromjson(js)
_, output_shape, __ = net.infer_shape(lhs=lhs_shape, rhs=rhs_shape)

assert output_shape[0] == dst_shape, \
Expand Down Expand Up @@ -2791,7 +2791,7 @@ def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shap
rhs = mx.sym.Variable("rhs")
net = mx.sym.reshape_like(lhs, rhs)
js = net.tojson()
net = mx.sym.load_json(js)
net = mx.sym.fromjson(js)
_, output_shape, __ = net.infer_shape(lhs=(40, 30), rhs=(30,20,2))
assert(output_shape[0] == (30,20,2))

Expand Down