Skip to content

Commit

Permalink
Change jit.save/load configs to config & update code examples (#27056)
Browse files Browse the repository at this point in the history
* change configs to config & update examples

* fix deprecate decorator conflict
  • Loading branch information
chenwhql committed Sep 7, 2020
1 parent 0443b48 commit c1a8868
Show file tree
Hide file tree
Showing 2 changed files with 298 additions and 275 deletions.
16 changes: 9 additions & 7 deletions python/paddle/fluid/dygraph/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import warnings
from .. import core
from .base import guard
from paddle.fluid.dygraph.jit import SaveLoadConfig
from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers

__all__ = [
Expand All @@ -42,17 +42,17 @@ def __warn_and_build_configs__(keep_name_table):
warnings.warn(
"The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`.",
DeprecationWarning)
configs = SaveLoadConfig()
configs.keep_name_table = keep_name_table
return configs
config = SaveLoadConfig()
config.keep_name_table = keep_name_table
return config

# deal with arg `keep_name_table`
if len(args) > 1 and isinstance(args[1], bool):
args = list(args)
args[1] = __warn_and_build_configs__(args[1])
# deal with kwargs
elif 'keep_name_table' in kwargs:
kwargs['configs'] = __warn_and_build_configs__(kwargs[
kwargs['config'] = __warn_and_build_configs__(kwargs[
'keep_name_table'])
kwargs.pop('keep_name_table')
else:
Expand Down Expand Up @@ -135,8 +135,9 @@ def save_dygraph(state_dict, model_path):
# TODO(qingqing01): remove dygraph_only to support loading static model.
# maybe need to unify the loading interface after 2.0 API is ready.
# @dygraph_only
@deprecate_save_load_configs
@deprecate_keep_name_table
def load_dygraph(model_path, configs=None):
def load_dygraph(model_path, config=None):
'''
:api_attr: imperative
Expand All @@ -151,7 +152,7 @@ def load_dygraph(model_path, configs=None):
Args:
model_path(str) : The file prefix store the state_dict.
(The path should Not contain suffix '.pdparams')
configs (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
object that specifies additional configuration options, these options
are for compatibility with ``jit.save/io.save_inference_model`` formats.
Default None.
Expand Down Expand Up @@ -195,6 +196,7 @@ def load_dygraph(model_path, configs=None):
opti_file_path = model_prefix + ".pdopt"

# deal with argument `configs`
configs = config
if configs is None:
configs = SaveLoadConfig()

Expand Down
Loading

0 comments on commit c1a8868

Please sign in to comment.