-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support dynamic export for dynabert #3549
Support dynamic export for dynabert #3549
Conversation
def load_parameters(dynabert_model, ori_state_dict): | ||
dynabert_state_dict = dynabert_model.state_dict() | ||
for key in ori_state_dict.keys(): | ||
dynabert_key = key.replace(".fn", "") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块的通用性如何? 是否要做一些限制了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块的通用性如何? 是否要做一些限制了?
这里不太明白为什么对fn的参数做一个替换
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ofa model的参数被改写过,模型的state_dict的key中都有个'.fn',所以在set参数之前,需要把'.fn'去掉
Codecov Report
@@ Coverage Diff @@
## develop #3549 +/- ##
========================================
Coverage 36.33% 36.33%
========================================
Files 419 419
Lines 59226 59221 -5
========================================
- Hits 21520 21519 -1
+ Misses 37706 37702 -4
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
if "ptq" in args.strategy: | ||
self.args.input_filename_prefix = "pruned_model" | ||
if "ptq" in args.strategy or "qat" in args.strategy: | ||
self.args.input_filename_prefix = "model" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里将前缀改变的原因是啥了? 看起来其他的model zoo模型也要改这个前缀
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
export_model
API 保存的模型前缀是model
,不支持自定义前缀名称
def load_parameters(dynabert_model, ori_state_dict): | ||
dynabert_state_dict = dynabert_model.state_dict() | ||
for key in ori_state_dict.keys(): | ||
dynabert_key = key.replace(".fn", "") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块的通用性如何? 是否要做一些限制了?
这里不太明白为什么对fn的参数做一个替换
if len(dynabert_shape) == 2: | ||
dynabert_state_dict[dynabert_key] = ori_state_dict[key][: dynabert_shape[0], : dynabert_shape[1]] | ||
else: | ||
dynabert_state_dict[dynabert_key] = ori_state_dict[key][: dynabert_shape[0]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议这里对shape不为1,2的抛一个报错,防止用户对其他模型使用有问题
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感谢提醒,已经增加了对异常处理的逻辑
…nto support-dynamic-dynabert
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…nto support-dynamic-dynabert
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
APIs
Description
Support dynamic export for dynabert