-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support load state dict form inference model
format save result
#26718
Support load state dict form inference model
format save result
#26718
Conversation
Thanks for your contribution! |
…' of https://github.com/chenwhql/Paddle into saveload/support_load_inference_model_format_state_dict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
def load_dygraph(model_path, keep_name_table=False): | ||
# @dygraph_only | ||
@deprecate_keep_name_table | ||
def load_dygraph(model_path, configs=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
configs -> config
这里建议用单数,因为SaveLoadConfig是单数形式
config = paddle.SaveLoadConfig()
可以后续统一再改下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,在下一个PR中修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里能否将SaveLoadConfig改为SaveLoadConfigs,这样兼容性更好处理一些,因为还涉及到jit.save和jit.load接口中的configs参数
paddle.save(state_dict, "paddle_dy") | ||
|
||
configs = paddle.SaveLoadConfig() | ||
configs.keep_name_table = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
keep_name_table默认设置为True的话,会有什么问题吗?
不是很了解背景
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
也没有什么风险,只是用户载入的state_dict里面会多一些额外信息,这些信息对用户一般没有帮助,保留这个选项主要是为了兼容,怕有同学用利用了旧版实现里面的这个信息
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
或许我们可以先将这个参数作为内部参数,不向用户公开
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…ddlePaddle#26718) * support load infer model format state dict * add unittests * remove keep name table * recolve circle inport * fix compatible problem * recover unittest * polish doc and comment
* Update set_dict method name & add aliases (#26700) * update set_dict method name & add aliases * fix var name error * fix alias formats * use set_state_dict in unittest * add decorator solve compatible problem * polish decorator * replace layer set_state_dict by patched method * remove import monkey path layer * fix import function error * add unittest for coverage * Support load state dict form `inference model` format save result (#26718) * support load infer model format state dict * add unittests * remove keep name table * recolve circle inport * fix compatible problem * recover unittest * polish doc and comment * Change jit.save/load configs to config & update code examples (#27056) * change configs to config & update examples * fix deprecate decorator conflict
PR types
New features
PR changes
APIs
Describe
【兼容升级】This PR supprot loading
state_dict
fromsave_inference_model
save result.This is an important feature, and we often see related problems in issues. After we switched from static graph mode to dynamic graph mode, many users rewritten the dynamic graph model code and needed to load from the result of
save_inferenece_model
in original static graph model, such as: #26519API changes:
paddle.load(model_path, keep_name_table=False)
->paddle.load(model_path, configs=None)
: add decoratordeprecate_keep_name_table
for compatibilitypaddle.jit.SaveLoadConfig
->paddle.SaveLoadConfig
: becausepaddle.load
also use it, move it to paddle module,SaveLoadConfig
is a class for two save format's compatibility, not only a jit related classpaddle.SaveLoadConfig.keep_name_table
: for compatible with oldpaddle.load(model_path, keep_name_table=False)
After this PR, the Paddle 2.0 Save/Load inferfaces are like follow:
This PR creates the
green line
in graph.Doc change
related PR: PaddlePaddle/docs#2531