Skip to content

Commit

Permalink
Update 2.0 Save/Load API names/arguments/doc examples (#27138)
Browse files Browse the repository at this point in the history
* Update set_dict method name & add aliases (#26700)

* update set_dict method name & add aliases

* fix var name error

* fix alias formats

* use set_state_dict in unittest

* add decorator solve compatible problem

* polish decorator

* replace layer set_state_dict by patched method

* remove import monkey path layer

* fix import function error

* add unittest for coverage

* Support load state dict form `inference model` format save result (#26718)

* support load infer model format state dict

* add unittests

* remove keep name table

* recolve circle inport

* fix compatible problem

* recover unittest

* polish doc and comment

* Change jit.save/load configs to config & update code examples (#27056)

* change configs to config & update examples

* fix deprecate decorator conflict
  • Loading branch information
chenwhql authored Sep 8, 2020
1 parent 0072490 commit 2986184
Show file tree
Hide file tree
Showing 21 changed files with 814 additions and 628 deletions.
Empty file added paddle/http.log
Empty file.
1 change: 1 addition & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@
from .framework import no_grad #DEFINE_ALIAS
from .framework import save #DEFINE_ALIAS
from .framework import load #DEFINE_ALIAS
from .framework import SaveLoadConfig #DEFINE_ALIAS
from .framework import DataParallel #DEFINE_ALIAS

from .framework import NoamDecay #DEFINE_ALIAS
Expand Down
168 changes: 98 additions & 70 deletions python/paddle/fluid/dygraph/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,54 @@

import os
import collections
import functools
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer
import pickle
import six
from . import learning_rate_scheduler
import warnings
from .. import core
from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME, _load_persistable_vars
from .base import guard
from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers

__all__ = [
'save_dygraph',
'load_dygraph',
]


# NOTE(chenweihang): deprecate load_dygraph's argument keep_name_table,
# ensure compatibility when user still use keep_name_table argument
def deprecate_keep_name_table(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
def __warn_and_build_configs__(keep_name_table):
warnings.warn(
"The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`.",
DeprecationWarning)
config = SaveLoadConfig()
config.keep_name_table = keep_name_table
return config

# deal with arg `keep_name_table`
if len(args) > 1 and isinstance(args[1], bool):
args = list(args)
args[1] = __warn_and_build_configs__(args[1])
# deal with kwargs
elif 'keep_name_table' in kwargs:
kwargs['config'] = __warn_and_build_configs__(kwargs[
'keep_name_table'])
kwargs.pop('keep_name_table')
else:
# do nothing
pass

return func(*args, **kwargs)

return wrapper


@dygraph_only
def save_dygraph(state_dict, model_path):
'''
Expand Down Expand Up @@ -100,41 +134,56 @@ def save_dygraph(state_dict, model_path):

# TODO(qingqing01): remove dygraph_only to support loading static model.
# maybe need to unify the loading interface after 2.0 API is ready.
#@dygraph_only
def load_dygraph(model_path, keep_name_table=False):
# @dygraph_only
@deprecate_save_load_configs
@deprecate_keep_name_table
def load_dygraph(model_path, config=None):
'''
:api_attr: imperative
Load parameter state_dict from disk.
Load parameter state dict from disk.
.. note::
Due to some historical reasons, if you load ``state_dict`` from the saved
result of `paddle.io.save_inference_model`, the structured variable name
will cannot be restored. You need to set the argument `use_structured_name=False`
when using `Layer.set_state_dict` later.
Args:
model_path(str) : The file prefix store the state_dict. (The path should Not contain suffix '.pdparams')
keep_name_table(bool, optional) : Whether keep structed name to parameter name conversion table in output dict.
Default : False
model_path(str) : The file prefix store the state_dict.
(The path should Not contain suffix '.pdparams')
config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
object that specifies additional configuration options, these options
are for compatibility with ``jit.save/io.save_inference_model`` formats.
Default None.
Returns:
state_dict(dict) : the dict store the state_dict
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
with fluid.dygraph.guard():
emb = fluid.dygraph.Embedding([10, 10])
paddle.disable_static()
state_dict = emb.state_dict()
fluid.save_dygraph( state_dict, "paddle_dy")
emb = paddle.nn.Embedding([10, 10])
adam = fluid.optimizer.Adam( learning_rate = fluid.layers.noam_decay( 100, 10000),
parameter_list = emb.parameters() )
state_dict = adam.state_dict()
fluid.save_dygraph( state_dict, "paddle_dy")
state_dict = emb.state_dict()
paddle.save(state_dict, "paddle_dy")
para_state_dict, opti_state_dict = fluid.load_dygraph( "paddle_dy")
scheduler = paddle.optimizer.lr_scheduler.NoamLR(
d_model=0.01, warmup_steps=100, verbose=True)
adam = paddle.optimizer.Adam(
learning_rate=scheduler,
parameters=emb.parameters())
state_dict = adam.state_dict()
paddle.save(state_dict, "paddle_dy")
'''
para_state_dict, opti_state_dict = paddle.load("paddle_dy")
'''
# deal with argument `model_path`
model_prefix = model_path
if model_prefix.endswith(".pdparams"):
model_prefix = model_prefix[:-9]
Expand All @@ -145,66 +194,45 @@ def load_dygraph(model_path, keep_name_table=False):
opti_dict = None
params_file_path = model_prefix + ".pdparams"
opti_file_path = model_prefix + ".pdopt"

# deal with argument `configs`
configs = config
if configs is None:
configs = SaveLoadConfig()

if not os.path.exists(params_file_path) and not os.path.exists(
opti_file_path):
# Load state dict by `jit.save` save format
# TODO(chenweihang): [Why not support `io.save_infernece_model` save format here]
# Load state dict by `jit.save/io.save_inference_model` save format
# NOTE(chenweihang): [ Compatibility of save_inference_model save format ]
# The model saved by `save_inference_model` does not completely correspond to
# the information required by the `state_dict` under the dygraph.
# Although we reluctantly restore the `state_dict` in some scenarios,
# this may not be complete and there are some limitations, so this function
# will be considered later. The limitations include:
# 1. `save_inference_model` not save structured name, we need to remind
# the user to configure the `use_structured_name` argument when `set_dict`,
# but this argument is currently not public
# 2. if `save_inference_model` save all persistable variables in a single file,
# user need to give the variable name list to load `state_dict`
# `save_inference_model` not save structured name, we need to remind
# the user to configure the `use_structured_name` argument when `set_state_dict`
# NOTE(chenweihang): `jit.save` doesn't save optimizer state

# 1. check model path
if not os.path.isdir(model_prefix):
raise ValueError("Model saved directory '%s' is not exists." %
model_prefix)
# 2. load `__variables.info__`
var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME)
if not os.path.exists(var_info_path):
raise RuntimeError(
"No target can be loaded. Now only supports loading `state_dict` from "
"the result saved by `imperative.save` and `imperative.jit.save`."
)
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f)
# 3. load `__variables__`
# TODO(chenweihang): now only supports loading from default save format:
# - all persistable vars saved in one file named `__variables__`
# for other case, we may need to modify the arguments of this API
var_file_path = os.path.join(model_prefix, VARIABLE_FILENAME)
if not os.path.exists(var_file_path):
raise RuntimeError(
"The parameter file to be loaded was not found. "
"Now only supports loading from the default save format, "
"and does not support custom params_filename and "
"save parameters separately.")
# 4. load all persistable vars
load_var_list = []
for name in sorted(extra_var_info):
var = _varbase_creator(name=name, persistable=True)
load_var_list.append(var)
_dygraph_tracer().trace_op(
type='load_combine',
inputs={},
outputs={'Out': load_var_list},
attrs={'file_path': var_file_path})
# 5. construct state_dict
para_dict = dict()
for var in load_var_list:
structured_name = extra_var_info[var.name].get('structured_name',
None)
if structured_name is None:
raise RuntimeError(
"Cannot find saved variable (%s)'s structured name in saved model.",
var.name)
para_dict[structured_name] = var.numpy()
# NOTE: `jit.save` doesn't save optimizer state

# 2. load program desc & construct _ProgramHolder
programs = _construct_program_holders(model_path,
configs.model_filename)

# 3. load layer parameters & buffers
# NOTE: using fluid.dygraph.guard() here will cause import error in py2
with guard():
persistable_var_dict = _construct_params_and_buffers(
model_prefix,
programs,
configs.separate_params,
configs.params_filename,
append_suffix=False)

# 4. construct state_dict
para_dict = dict()
for var_name in persistable_var_dict:
para_dict[var_name] = persistable_var_dict[var_name].numpy()
else:
# Load state dict by `save_dygraph` save format
para_dict = {}
Expand All @@ -213,7 +241,7 @@ def load_dygraph(model_path, keep_name_table=False):
para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')

if not keep_name_table and "StructuredToParameterName@@" in para_dict:
if not configs.keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"]

if os.path.exists(opti_file_path):
Expand Down
18 changes: 16 additions & 2 deletions python/paddle/fluid/dygraph/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,15 @@ def _load_persistable_vars(model_path,
return load_var_dict


# NOTE(chenweihang): to adapt paddle.load to get state_dict
def _remove_varname_suffix(var_dict, program_holder):
no_suffix_var_dict = dict()
for var_name in var_dict:
no_suffix_name = program_holder._suffix_varname_dict[var_name]
no_suffix_var_dict[no_suffix_name] = var_dict[var_name]
return no_suffix_var_dict


def _construct_program_holders(model_path, model_filename=None):
# make sure the path has been checked
program_holder_dict = dict()
Expand Down Expand Up @@ -517,7 +526,8 @@ def _construct_program_holders(model_path, model_filename=None):
def _construct_params_and_buffers(model_path,
programs,
separate_params=False,
params_filename=None):
params_filename=None,
append_suffix=True):
var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
if os.path.exists(var_info_path):
var_dict = _load_persistable_vars(model_path, var_info_path,
Expand All @@ -526,6 +536,10 @@ def _construct_params_and_buffers(model_path,
else:
var_dict = _load_persistable_vars_by_program(
model_path, programs['forward'], params_filename)

if not append_suffix:
var_dict = _remove_varname_suffix(var_dict, programs['forward'])

return var_dict


Expand Down Expand Up @@ -685,7 +699,7 @@ def _construct(model_path, configs=None):
# 1. load program desc & construct _ProgramHolder
programs = _construct_program_holders(model_path, model_filename)

# 2. load layer parameters & parameter attributes
# 2. load layer parameters & buffers
persistable_vars = _construct_params_and_buffers(
model_path, programs, separate_params, params_filename)

Expand Down
Loading

0 comments on commit 2986184

Please sign in to comment.