This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Save/Load Gluon Blocks & HybridBlocks (#19564)
samskalicky authored Dec 14, 2020
1 parent ed4aa23 commit 765297b
Showing 6 changed files with 209 additions and 11 deletions.
140 changes: 139 additions & 1 deletion python/mxnet/gluon/block.py
@@ -29,13 +29,14 @@
import contextvars

import re
import json
import numpy as np

from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
profiler as _profiler, context as _context
from ..symbol.numpy import _symbol as np_symbol
from ..symbol import Symbol
from ..symbol import Symbol, fromjson
from ..ndarray import NDArray
from .parameter import Parameter, DeferredInitializationError
from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
@@ -573,6 +574,143 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
        for v in params.values():
            v.initialize(None, ctx, init, force_reinit=force_reinit)

    def save(self, prefix):
        """Save the model architecture and parameters so they can be loaded again later.

        The model architecture is saved as a nested dictionary in which each Block
        in the model is a dictionary and its children are sub-dictionaries. Each
        Block is uniquely identified by its class name plus a sequential ID, and
        each Block's parameter UUIDs are saved so the saved parameters can be
        matched up again on load.

        The children of a Block are traversed recursively in order (they are kept
        in an OrderedDict), and the sequential ID denotes that specific Block. This
        assumes the model is created in an identical order every time. If the model
        cannot be recreated deterministically, do not use this set of APIs to
        save/load your model.

        For HybridBlocks, the cached graph (Symbol and inputs) is saved if the
        block has already been hybridized.

        Parameters
        ----------
        prefix : str
            The prefix to use in filenames for saving this model:
            <prefix>-model.json and <prefix>-model.params
        """
        # create empty model structure
        model = {}
        def _save_cached_graphs(blk, structure, index=0):
            # create new entry for this block
            mdl = {}
            # encode unique name based on block type and ID
            name = type(blk).__name__.lower()
            structure[name+str(index)] = mdl
            index += 1
            if isinstance(blk, HybridBlock):
                if blk._cached_graph:
                    # save in/out formats
                    mdl['in_format'] = blk._in_format
                    mdl['out_format'] = blk._out_format
                    # save cached graph & input symbols
                    syms, out = blk._cached_graph
                    mdl_syms = []
                    for sym in syms:
                        mdl_syms.append(sym.tojson())
                    mdl['inputs'] = mdl_syms
                    mdl['symbol'] = out.tojson()
                    mdl['hybridized'] = True
                else:
                    mdl['hybridized'] = False
            # save param uuids
            pmap = {}
            mdl['params'] = pmap
            pnames = list(blk.params.keys())
            for p in pnames:
                param = blk.params[p]
                pmap[p] = param._uuid
            # recursively save children
            for child in blk._children.values():
                index = _save_cached_graphs(child(), mdl, index)
            # return latest index (i.e. block count)
            return index

        # save top-level block
        _save_cached_graphs(self, model)
        # save model
        with open(prefix+'-model.json', 'w') as fp:
            json.dump(model, fp)
        # save params
        self.save_parameters(prefix+'-model.params')

    def load(self, prefix):
        """Load a model saved using the `save` API.

        Reconfigures a model using the saved configuration. This function does
        not regenerate the model architecture; it resets each Block's parameter
        UUIDs to the values they had when the model was saved, so that they
        match the names of the saved parameters.

        This function assumes the Blocks in the model were created in the same
        order they were when the model was saved. Each Block is uniquely
        identified by its class name plus a sequential ID assigned while
        traversing the children in order (they are kept in an OrderedDict), and
        the sequential ID denotes that specific Block. This assumes the model is
        created in an identical order every time. If the model cannot be
        recreated deterministically, do not use this set of APIs to save/load
        your model.

        For HybridBlocks, the cached graph (Symbol and inputs) and settings are
        restored if the block had been hybridized before saving.

        Parameters
        ----------
        prefix : str
            The prefix to use in filenames for loading this model:
            <prefix>-model.json and <prefix>-model.params
        """
        # load model json from file
        with open(prefix+'-model.json') as fp:
            model = json.load(fp)

        def _load_cached_graphs(blk, structure, index=0):
            # get block name
            name = type(blk).__name__.lower()
            # lookup previous encoded name based on block type and ID
            mdl = structure[name+str(index)]
            index += 1
            if isinstance(blk, HybridBlock):
                if mdl['hybridized']:
                    # restore in/out formats
                    blk._in_format = mdl['in_format']
                    blk._out_format = mdl['out_format']
                    # get saved symbol
                    out = fromjson(mdl['symbol'])
                    syms = []
                    # recreate inputs for this symbol
                    for inp in mdl['inputs']:
                        syms.append(fromjson(inp))
                    # reset cached_graph and active status
                    blk._cached_graph = (syms, out)
                    blk._active = True
            # reload param uuids
            pmap = mdl['params']
            for p, uuid in pmap.items():
                param = blk.params[p]
                param._uuid = uuid
            # recursively reload children
            for child in blk._children.values():
                index = _load_cached_graphs(child(), mdl, index)
            # return latest index (i.e. block count)
            return index

        # load top-level block
        _load_cached_graphs(self, model)
        # load params
        self.load_parameters(prefix+'-model.params')

    def hybridize(self, active=True, **kwargs):
        """ Please refer to the description of HybridBlock.hybridize().
        """
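For orientation, here is a rough sketch of the nested structure that save() writes to <prefix>-model.json, using the two-MyBlock/two-Dense model from tests/python/unittest/test_gluon_save.py further down as the example. It is illustrative only: the UUIDs, format codes, and Symbol JSON payloads are placeholders, and plain (non-hybrid) Blocks carry no 'hybridized' entry.

# Illustrative sketch (not part of the diff) of the dictionary Block.save()
# serializes for the test model MyBlock -> (MyBlock -> Dense), Dense.
# Keys combine the lowercased class name with a running index assigned during
# the ordered recursive traversal; placeholder values are marked with <...>.
model = {
    "myblock0": {
        "params": {},                               # MyBlock owns no parameters of its own
        "myblock1": {
            "params": {},
            "dense2": {
                "in_format": [0],                   # placeholder format codes
                "out_format": [0],
                "inputs": ["<input Symbol JSON>"],  # one JSON string per cached input
                "symbol": "<cached graph Symbol JSON>",
                "hybridized": True,
                "params": {"weight": "<uuid>", "bias": "<uuid>"},
            },
        },
        "dense3": {
            "in_format": [0],
            "out_format": [0],
            "inputs": ["<input Symbol JSON>"],
            "symbol": "<cached graph Symbol JSON>",
            "hybridized": True,
            "params": {"weight": "<uuid>", "bias": "<uuid>"},
        },
    },
}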
6 changes: 3 additions & 3 deletions python/mxnet/symbol/symbol.py
@@ -46,7 +46,7 @@
from ..profiler import scope as _profiler_scope
from ..profiler import _current_scope as _current_profiler_scope

__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
__all__ = ["Symbol", "var", "Variable", "Group", "load", "fromjson",
"pow", "power", "maximum", "minimum", "hypot", "eye", "zeros",
"ones", "full", "arange", "linspace", "histogram", "split_v2"]

@@ -1400,7 +1400,7 @@ def tojson(self, remove_amp_cast=True):
        See Also
        --------
        symbol.load_json : Used to load symbol from JSON string.
        symbol.fromjson : Used to load symbol from JSON string.
        """
        json_str = ctypes.c_char_p()
        if remove_amp_cast:
@@ -2821,7 +2821,7 @@ def load(fname):
    return Symbol(handle)


def load_json(json_str):
def fromjson(json_str):
    """Loads symbol from json string.
    Parameters
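The rename is purely at the module level: Symbol.tojson() is unchanged, and fromjson() replaces load_json() as the string-based loader. A minimal round-trip sketch of the renamed API, mirroring the test updates below:

import mxnet as mx

# Build a small symbolic expression, serialize it to JSON, and restore it with
# the renamed loader (previously mx.sym.load_json).
a = mx.sym.var('a')
b = mx.sym.var('b')
c = 2 * a + b

js = c.tojson()             # JSON string, same API as before
c2 = mx.sym.fromjson(js)    # renamed from load_json in this commit
assert js == c2.tojson()    # the round trip is stable, as the tests below check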
4 changes: 2 additions & 2 deletions tests/python/unittest/test_contrib_control_flow.py
@@ -1135,7 +1135,7 @@ def verify_foreach(step, in_syms, state_syms, free_syms,
out.extend(states)
out = mx.sym.Group(out)
js_1 = out.tojson()
out = mx.sym.load_json(js_1)
out = mx.sym.fromjson(js_1)
js_2 = out.tojson()
assert js_1 == js_2
arr_grads = []
@@ -1457,7 +1457,7 @@ def step_nd(in1, states):
out = mx.sym.broadcast_add(out, states[0])

js_1 = out.tojson()
out = mx.sym.load_json(js_1)
out = mx.sym.fromjson(js_1)
js_2 = out.tojson()
assert js_1 == js_2

2 changes: 1 addition & 1 deletion tests/python/unittest/test_contrib_operator.py
@@ -415,7 +415,7 @@ def dynamic_reshape_testcases(src_shape, shape_arg, dst_shape):
shape = mx.sym.Variable('shape')
net = mx.sym.contrib.dynamic_reshape(data, shape)
js = net.tojson()
net = mx.sym.load_json(js)
net = mx.sym.fromjson(js)
dat_npy = np.random.rand(*src_shape)
grad_npy = np.random.rand(*dst_shape)
args = {
60 changes: 60 additions & 0 deletions tests/python/unittest/test_gluon_save.py
@@ -0,0 +1,60 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx

def test_save():
    class MyBlock(mx.gluon.Block):
        def __init__(self, **kwargs):
            super(MyBlock, self).__init__(**kwargs)
            self.layers = []
        def add(self, block):
            self.layers.append(block)
            self.register_child(block)
        def forward(self, x, *args):
            out = (x,) + args
            for block in self._children.values():
                out = block()(*out)
            return out

    def createNet():
        inside = MyBlock()
        dense = mx.gluon.nn.Dense(10)
        inside.add(dense)
        net = MyBlock()
        net.add(inside)
        net.add(mx.gluon.nn.Dense(10))
        return net

    # create and initialize model
    net1 = createNet()
    net1.initialize()
    # hybridize (the hybridizable blocks, i.e. the Dense layers)
    net1.hybridize()
    x = mx.nd.zeros((1,10))
    out1 = net1(x)

    # save hybridized model
    net1.save('MyModel')

    # create a new model, uninitialized
    net2 = createNet()
    # reload hybridized model
    net2.load('MyModel')
    # run inference again
    out2 = net2(x)
    mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy())
8 changes: 4 additions & 4 deletions tests/python/unittest/test_operator.py
@@ -2269,7 +2269,7 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape):
net = mx.sym.Variable("data")
net = mx.sym.Reshape(net, shape=shape_args, reverse=reverse)
js = net.tojson()
net = mx.sym.load_json(js)
net = mx.sym.fromjson(js)
_, output_shape, __ = net.infer_shape(data=src_shape)
assert output_shape[0] == dst_shape, \
'Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
@@ -2308,7 +2308,7 @@ def test_reshape_old():
net = mx.sym.Variable("data")
net = mx.sym.Reshape(net, target_shape=(2, 0))
js = net.tojson()
net = mx.sym.load_json(js)
net = mx.sym.fromjson(js)
_, output_shape, __ = net.infer_shape(data=(2, 3, 5, 5))
assert(output_shape[0] == (2, 75))
# Test for Flatten
@@ -2329,7 +2329,7 @@ def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shap
rhs = mx.sym.Variable("rhs")
net = mx.sym.reshape_like(lhs, rhs, lhs_begin=lbeg, lhs_end=lend, rhs_begin=rbeg, rhs_end=rend)
js = net.tojson()
net = mx.sym.load_json(js)
net = mx.sym.fromjson(js)
_, output_shape, __ = net.infer_shape(lhs=lhs_shape, rhs=rhs_shape)

assert output_shape[0] == dst_shape, \
@@ -2370,7 +2370,7 @@ def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shap
rhs = mx.sym.Variable("rhs")
net = mx.sym.reshape_like(lhs, rhs)
js = net.tojson()
net = mx.sym.load_json(js)
net = mx.sym.fromjson(js)
_, output_shape, __ = net.infer_shape(lhs=(40, 30), rhs=(30,20,2))
assert(output_shape[0] == (30,20,2))
