Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Refactor Gluon parameter serialization format #18749

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 61 additions & 52 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from collections import OrderedDict, defaultdict
import contextlib
import contextvars
import zipfile

import re
import numpy as np
Expand Down Expand Up @@ -336,41 +337,43 @@ def _collect_params_with_prefix(self, prefix='', select=None):
ret.update(child()._collect_params_with_prefix(prefix + name, select))
return ret

def save_parameters(self, filename, deduplicate=False):
"""Save parameters to file.
def save_parameters(self, filename):
"""Save parameters to file based on numpy .npz format.

Saved parameters can only be loaded with `load_parameters`. Note that this
method only saves parameters, not model structure. If you want to save
model structures, please use :py:meth:`HybridBlock.export`.
Saved parameters can be loaded with `load_parameters` or numpy.load.
Note that this method only saves parameters, not model structure. If
you want to save model structures, please use :py:meth:`HybridBlock.export`.

Parameters
----------
filename : str
Path to file.
deduplicate : bool, default False
If True, save shared parameters only once. Otherwise, if a Block
contains multiple sub-blocks that share parameters, each of the
shared parameters will be separately saved for every sub-block.
filename : str or file
Either the filename (string) or an open file (file-like object)
where the data will be saved.

References
----------
`Saving and Loading Gluon Models \
<https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html>`_
"""
params = self._collect_params_with_prefix()

if deduplicate:
# Shared parameters are stored only a single time as of MXNet 1.6.
# Shared parameters are registered under multiple prefixes returned by
# _collect_params_with_prefix. We select a single one and only store
# it. In load_parameters it is sufficient for a shared parameter to
# only set it for a single prefix.
reverse_params = {v: k for k, v in params.items()}
params = {v: k for k, v in reverse_params.items()}

arg_dict = {key: val._reduce() for key, val in params.items()}
save_fn = _mx_npx.save if is_np_array() else ndarray.save
save_fn(filename, arg_dict)
"""
params_to_names = {}
for name, param in self._collect_params_with_prefix().items():
params_to_names.setdefault(param, []).append(name)
params = {}
for param, names in params_to_names.items():
assert len(names)
params[names[0]] = param._reduce().asnumpy()
for name in names[1:]:
# Shared parameters are known under multiple names. We save the
# parameter according to its first name and save the mapping
# to the first name for the other names.
params[name] = names[0]

if hasattr(filename, 'write'): # file object
return np.savez(filename, **params)
with open(filename, 'w+b') as f: # filename
# Avoid np.savez modifying the filename by passing file object
return np.savez(f, **params)

def load_parameters(self, filename, ctx=None, allow_missing=False,
ignore_extra=False, cast_dtype=False, dtype_source='current'):
Expand Down Expand Up @@ -399,28 +402,32 @@ def load_parameters(self, filename, ctx=None, allow_missing=False,
`Saving and Loading Gluon Models \
<https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html>`_
"""
if is_np_array():
# failure may happen when loading parameters saved as NDArrays within
# NumPy semantics. Check the failure type and recover from it if it happens.
try:
loaded = _mx_npx.load(filename)
except MXNetError as e:
err_msg = str(e)
if 'is_np_shape' in err_msg:
# Loading failure due to parameters saved without numpy semantics.
# Temporarily disable numpy semantics and load parameters. After it's
# done, resume the numpy semantics. This is fine because the cases
# numpy ndarray covers is a superset of the legacy ndarray's.
with np_array(False):
with np_shape(False):
loaded_nds = ndarray.load(filename)
assert isinstance(loaded_nds, dict),\
'expecting a dict type, got {}'.format(str(type(loaded_nds)))
loaded = {k: loaded_nds[k].as_np_ndarray() for k in loaded_nds}
else:
raise ValueError(err_msg)
else:
loaded = ndarray.load(filename)
if zipfile.is_zipfile(filename):
loaded = np.load(filename)
loaded = {n: (p if not isinstance(p, str) else loaded[p]) for n, p in loaded.items()}
else: # Try MXNet deprecated format
if is_np_array():
# failure may happen when loading parameters saved as NDArrays within
# NumPy semantics. Check the failure type and recover from it if it happens.
try:
loaded = _mx_npx.load(filename)
except MXNetError as e:
err_msg = str(e)
if 'is_np_shape' in err_msg:
# Loading failure due to parameters saved without numpy semantics.
# Temporarily disable numpy semantics and load parameters. After it's
# done, resume the numpy semantics. This is fine because the cases
# numpy ndarray covers is a superset of the legacy ndarray's.
with np_array(False):
with np_shape(False):
loaded_nds = ndarray.load(filename)
assert isinstance(loaded_nds, dict),\
'expecting a dict type, got {}'.format(str(type(loaded_nds)))
loaded = {k: loaded_nds[k].as_np_ndarray() for k in loaded_nds}
else:
raise ValueError(err_msg)
else:
loaded = ndarray.load(filename)

if not loaded:
return
Expand Down Expand Up @@ -483,7 +490,8 @@ def load_dict(self, param_dict, ctx=None, allow_missing=False,
if name in params:
param = loaded[name]
if isinstance(param, np.ndarray):
param = _mx_np.array(param) if is_np_array() else nd.array(param)
param = _mx_np.array(param, dtype=param.dtype) if is_np_array() \
else nd.array(param, dtype=param.dtype)
params[name]._load_init(param, ctx, cast_dtype=cast_dtype, dtype_source=dtype_source)

def register_child(self, block, name=None):
Expand Down Expand Up @@ -1343,13 +1351,14 @@ def export(self, path, epoch=0, remove_amp_cast=True):
for is_arg, name, param in self._cached_op_args:
if not is_arg:
if name in arg_names:
arg_dict['arg:{}'.format(name)] = param._reduce()
arg_dict['arg:{}'.format(name)] = param._reduce().asnumpy()
else:
assert name in aux_names
arg_dict['aux:{}'.format(name)] = param._reduce()
save_fn = _mx_npx.save if is_np_array() else ndarray.save
arg_dict['aux:{}'.format(name)] = param._reduce().asnumpy()
params_filename = '%s-%04d.params'%(path, epoch)
save_fn(params_filename, arg_dict)
with open(params_filename, 'w+b') as f:
# Avoid np.savez modifying the filename by passing file object
np.savez(f, **arg_dict)
return (sym_filename, params_filename)

def register_op_hook(self, callback, monitor_all=False):
Expand Down
7 changes: 4 additions & 3 deletions python/mxnet/numpy_extension/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
# under the License.

"""Util functions for the numpy module."""



import ctypes
import warnings

from .. util import is_np_array, is_np_shape
from .. base import _LIB, check_call, string_types, c_str_array, DLPackHandle
from .. base import c_handle_array, c_str, mx_uint, NDArrayHandle, py_str
Expand Down Expand Up @@ -61,6 +60,7 @@ def save(file, arr):
This function can only be called within numpy semantics, i.e., `npx.is_np_shape()`
and `npx.is_np_array()` must both return true.
"""
warnings.warn("MXNet parameter serialization format is deprecated in favor of NumPy format.", DeprecationWarning)
if not (is_np_shape() and is_np_array()):
raise ValueError('Cannot save `mxnet.numpy.ndarray` in legacy mode. Please activate'
' numpy semantics by calling `npx.set_np()` in the global scope'
Expand Down Expand Up @@ -109,6 +109,7 @@ def load(file):
This function can only be called within numpy semantics, i.e., `npx.is_np_shape()`
and `npx.is_np_array()` must both return true.
"""
warnings.warn("MXNet parameter serialization format is deprecated in favor of NumPy format.", DeprecationWarning)
if not (is_np_shape() and is_np_array()):
raise ValueError('Cannot load `mxnet.numpy.ndarray` in legacy mode. Please activate'
' numpy semantics by calling `npx.set_np()` in the global scope'
Expand Down