Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polish jit.save/load design & remove paddle.SaveLoadConfig #27623

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@
from .framework import no_grad #DEFINE_ALIAS
from .framework import save #DEFINE_ALIAS
from .framework import load #DEFINE_ALIAS
from .framework import SaveLoadConfig #DEFINE_ALIAS
from .framework import DataParallel #DEFINE_ALIAS

from .framework import NoamDecay #DEFINE_ALIAS
Expand Down
20 changes: 11 additions & 9 deletions python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from paddle.fluid.dygraph.nn import Pool2D
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

paddle.enable_static()

Expand Down Expand Up @@ -231,10 +232,11 @@ def test_qat_save(self):
before_save = lenet(test_img)

# save inference quantized model
path = "./mnist_infer_model"
path = "./qat_infer_model/lenet"
save_dir = "./qat_infer_model"
paddle.jit.save(
layer=lenet,
model_path=path,
path=path,
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
Expand All @@ -245,12 +247,12 @@ def test_qat_save(self):
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
[inference_program, feed_target_names, fetch_targets] = (
fluid.io.load_inference_model(
dirname=path,
executor=exe,
model_filename="__model__",
params_filename="__variables__"))
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
dirname=save_dir,
executor=exe,
model_filename="lenet" + INFER_MODEL_SUFFIX,
params_filename="lenet" + INFER_PARAMS_SUFFIX)
after_save, = exe.run(inference_program,
feed={feed_target_names[0]: test_data},
fetch_list=fetch_targets)
Expand Down Expand Up @@ -339,7 +341,7 @@ def _build_static_lenet(main, startup, is_test=False, seed=1000):

paddle.jit.save(
layer=lenet,
model_path="./dynamic_mnist",
path="./dynamic_mnist/model",
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from paddle.fluid.dygraph.nn import Pool2D
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

paddle.enable_static()

Expand Down Expand Up @@ -231,10 +232,11 @@ def test_qat_save(self):
before_save = lenet(test_img)

# save inference quantized model
path = "./mnist_infer_model"
path = "./qat_infer_model/mnist"
save_dir = "./qat_infer_model"
paddle.jit.save(
layer=lenet,
model_path=path,
path=path,
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
Expand All @@ -245,12 +247,12 @@ def test_qat_save(self):
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
[inference_program, feed_target_names, fetch_targets] = (
fluid.io.load_inference_model(
dirname=path,
executor=exe,
model_filename="__model__",
params_filename="__variables__"))
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
dirname=save_dir,
executor=exe,
model_filename="mnist" + INFER_MODEL_SUFFIX,
params_filename="mnist" + INFER_PARAMS_SUFFIX)
after_save, = exe.run(inference_program,
feed={feed_target_names[0]: test_data},
fetch_list=fetch_targets)
Expand Down Expand Up @@ -339,7 +341,7 @@ def _build_static_lenet(main, startup, is_test=False, seed=1000):

paddle.jit.save(
layer=lenet,
model_path="./dynamic_mnist",
path="./dynamic_mnist/model",
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
Expand Down
73 changes: 31 additions & 42 deletions python/paddle/fluid/dygraph/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,44 +24,32 @@
import warnings
from .. import core
from .base import guard
from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers, EXTRA_VAR_INFO_FILENAME
from paddle.fluid.dygraph.jit import _SaveLoadConfig
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers

__all__ = [
'save_dygraph',
'load_dygraph',
]


# NOTE(chenweihang): deprecate load_dygraph's argument keep_name_table,
# ensure compatibility when user still use keep_name_table argument
def deprecate_keep_name_table(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
def __warn_and_build_configs__(keep_name_table):
warnings.warn(
"The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`.",
DeprecationWarning)
config = SaveLoadConfig()
config.keep_name_table = keep_name_table
return config

# deal with arg `keep_name_table`
if len(args) > 1 and isinstance(args[1], bool):
args = list(args)
args[1] = __warn_and_build_configs__(args[1])
# deal with kwargs
elif 'keep_name_table' in kwargs:
kwargs['config'] = __warn_and_build_configs__(kwargs[
'keep_name_table'])
kwargs.pop('keep_name_table')
else:
# do nothing
pass
def _parse_load_config(configs):
supported_configs = ['model_filename', 'params_filename', 'keep_name_table']

# input check
for key in configs:
if key not in supported_configs:
raise ValueError(
"The additional config (%s) of `paddle.fluid.load_dygraph` is not supported."
% (key))

return func(*args, **kwargs)
# construct inner config
inner_config = _SaveLoadConfig()
inner_config.model_filename = configs.get('model_filename', None)
inner_config.params_filename = configs.get('params_filename', None)
inner_config.keep_name_table = configs.get('keep_name_table', None)

return wrapper
return inner_config


@dygraph_only
Expand Down Expand Up @@ -132,12 +120,12 @@ def save_dygraph(state_dict, model_path):
pickle.dump(model_dict, f, protocol=2)


# NOTE(chenweihang): load_dygraph will deprecated in future, we don't
# support new loading features for it
# TODO(qingqing01): remove dygraph_only to support loading static model.
# maybe need to unify the loading interface after 2.0 API is ready.
# @dygraph_only
@deprecate_save_load_configs
@deprecate_keep_name_table
def load_dygraph(model_path, config=None):
def load_dygraph(model_path, **configs):
'''
:api_attr: imperative

Expand All @@ -152,10 +140,13 @@ def load_dygraph(model_path, config=None):
Args:
model_path(str) : The file prefix store the state_dict.
(The path should Not contain suffix '.pdparams')
config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
object that specifies additional configuration options, these options
are for compatibility with ``jit.save/io.save_inference_model`` formats.
Default None.
**configs (dict, optional): other save configuration options for compatibility. We do not
recommend using these configurations, if not necessary, DO NOT use them. Default None.
The following options are currently supported:
(1) model_filename (string): The inference model file name of the paddle 1.x ``save_inference_model``
save format. Default file name is :code:`__model__` .
(2) params_filename (string): The persistable variables file name of the paddle 1.x ``save_inference_model``
save format. No default file name, save variables separately by default.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't recommend users to use keep_name_table, only keep it for debuging, so don't write its doc.


Returns:
state_dict(dict) : the dict store the state_dict
Expand Down Expand Up @@ -196,8 +187,7 @@ def load_dygraph(model_path, config=None):
opti_file_path = model_prefix + ".pdopt"

# deal with argument `config`
if config is None:
config = SaveLoadConfig()
config = _parse_load_config(configs)

if os.path.exists(params_file_path) or os.path.exists(opti_file_path):
# Load state dict by `save_dygraph` save format
Expand Down Expand Up @@ -246,7 +236,6 @@ def load_dygraph(model_path, config=None):
persistable_var_dict = _construct_params_and_buffers(
model_prefix,
programs,
config.separate_params,
config.params_filename,
append_suffix=False)

Expand All @@ -255,9 +244,9 @@ def load_dygraph(model_path, config=None):
for var_name in persistable_var_dict:
para_dict[var_name] = persistable_var_dict[var_name].numpy()

# if __variables.info__ exists, we can recover structured_name
var_info_path = os.path.join(model_prefix,
EXTRA_VAR_INFO_FILENAME)
# if *.info exists, we can recover structured_name
var_info_filename = str(config.params_filename) + ".info"
var_info_path = os.path.join(model_prefix, var_info_filename)
if os.path.exists(var_info_path):
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f)
Expand Down
59 changes: 22 additions & 37 deletions python/paddle/fluid/dygraph/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@

__all__ = ['TranslatedLayer']

VARIABLE_FILENAME = "__variables__"
EXTRA_VAR_INFO_FILENAME = "__variables.info__"
INFER_MODEL_SUFFIX = ".pdmodel"
INFER_PARAMS_SUFFIX = ".pdiparams"
INFER_PARAMS_INFO_SUFFIX = ".pdiparams.info"

LOADED_VAR_SUFFIX = "load"
PARAMETER_NAME_PREFIX = "param"
BUFFER_NAME_PREFIX = "buffer"
Expand Down Expand Up @@ -424,11 +426,8 @@ def _load_persistable_vars_by_program(model_path,
return load_var_dict


def _load_persistable_vars(model_path,
var_info_path,
program_holder,
separate_params=False,
params_filename=None):
def _load_persistable_vars(model_path, var_info_path, program_holder,
params_filename):
# 1. load extra var info
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f)
Expand Down Expand Up @@ -464,33 +463,22 @@ def _load_persistable_vars(model_path,
new_var = framework._varbase_creator(
name=new_name, persistable=True)

# load separate vars
if separate_params is True:
framework._dygraph_tracer().trace_op(
type='load',
inputs={},
outputs={'Out': new_var},
attrs={'file_path': os.path.join(model_path, name)})

new_var.stop_gradient = extra_var_info[name]['stop_gradient']
load_var_dict[new_name] = new_var
load_var_list.append(new_var)

# 3. load all vars
if separate_params is False:
if params_filename is not None:
var_file_path = os.path.join(model_path, params_filename)
else:
var_file_path = os.path.join(model_path, VARIABLE_FILENAME)
if not os.path.exists(var_file_path):
if len(extra_var_info) != 0:
raise ValueError("The model to be loaded is incomplete.")
else:
framework._dygraph_tracer().trace_op(
type='load_combine',
inputs={},
outputs={'Out': load_var_list},
attrs={'file_path': var_file_path})
assert params_filename is not None, "params_filename should not be None."
var_file_path = os.path.join(model_path, params_filename)
if not os.path.exists(var_file_path):
if len(extra_var_info) != 0:
raise ValueError("The model to be loaded is incomplete.")
else:
framework._dygraph_tracer().trace_op(
type='load_combine',
inputs={},
outputs={'Out': load_var_list},
attrs={'file_path': var_file_path})

return load_var_dict

Expand Down Expand Up @@ -532,14 +520,13 @@ def _construct_program_holders(model_path, model_filename=None):

def _construct_params_and_buffers(model_path,
programs,
separate_params=False,
params_filename=None,
append_suffix=True):
var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
var_info_filename = str(params_filename) + ".info"
var_info_path = os.path.join(model_path, var_info_filename)
if os.path.exists(var_info_path):
var_dict = _load_persistable_vars(model_path, var_info_path,
programs['forward'], separate_params,
params_filename)
programs['forward'], params_filename)
else:
var_dict = _load_persistable_vars_by_program(
model_path, programs['forward'], params_filename)
Expand Down Expand Up @@ -700,18 +687,16 @@ def _construct(model_path, configs=None):
raise ValueError("There is no directory named '%s'" % model_path)
model_filename = None
params_filename = None
separate_params = False
if configs is not None:
model_filename = configs.model_filename
params_filename = configs.params_filename
separate_params = configs.separate_params

# 1. load program desc & construct _ProgramHolder
programs = _construct_program_holders(model_path, model_filename)

# 2. load layer parameters & buffers
persistable_vars = _construct_params_and_buffers(
model_path, programs, separate_params, params_filename)
persistable_vars = _construct_params_and_buffers(model_path, programs,
params_filename)

# 3. construct TranslatedLayer object
translated_layer = TranslatedLayer(programs, persistable_vars)
Expand Down
Loading