Support dynamic export for dynabert #3549

Merged
merged 10 commits into from
Dec 29, 2022

@LiuChiachi (Contributor) commented Oct 25, 2022

PR types

New features

PR changes

APIs

Description

Support dynamic export for dynabert

PYTHONPATH=/liujiaqi/PaddleNLP:/liujiaqi/PaddleSlim python compress_seq_cls.py    \
    --dataset   "clue cluewsc2020"     \
    --model_name_or_path ernie-3.0-tiny-nano-v2-zh   \
    --per_device_train_batch_size 32   \
    --output_dir ./test  \
    --per_device_eval_batch_size 32    \
    --num_train_epochs 5   \
    --width_mult_list 2/3   \
    --batch_size_list 4   \
    --algo_list 'abs_max'   \
    --strategy 'dynabert ptq embeddings'  \
    --onnx_format False

def load_parameters(dynabert_model, ori_state_dict):
    dynabert_state_dict = dynabert_model.state_dict()
    for key in ori_state_dict.keys():
        dynabert_key = key.replace(".fn", "")
@wawltor (Collaborator) commented Oct 25, 2022

How general is this? Should some restrictions be added here?

Collaborator commented:

> How general is this? Should some restrictions be added here?

I don't quite understand why the `.fn` in the parameter names is replaced here.

Author (Contributor) replied:

The OFA model's parameters have been rewritten, so every key in the model's state_dict contains '.fn'. The '.fn' therefore has to be removed before the parameters are set.
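
To make the key rewriting concrete, here is a minimal sketch of the renaming step on a plain dict. The helper name `strip_fn_keys` and the key names are hypothetical and only for illustration; the real state_dict holds Paddle tensors, not lists.

```python
def strip_fn_keys(ori_state_dict):
    """Return a copy of the state dict with '.fn' removed from every key."""
    return {key.replace(".fn", ""): value for key, value in ori_state_dict.items()}


# Hypothetical key names; values are placeholders standing in for tensors.
ofa_state = {
    "encoder.layers.0.linear1.fn.weight": [0.1, 0.2],
    "encoder.layers.0.linear1.fn.bias": [0.0],
}
plain_state = strip_fn_keys(ofa_state)
print(sorted(plain_state))
```

Note that a plain `str.replace` also rewrites any key that merely contains `.fn` as a substring, which is one way to read the generality concern raised above.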

codecov bot commented Dec 19, 2022

Codecov Report

Merging #3549 (7229c0d) into develop (240f817) will increase coverage by 0.00%.
The diff coverage is 7.14%.

❗ Current head 7229c0d differs from pull request most recent head 434151e. Consider uploading reports for the commit 434151e to get more accurate results

@@           Coverage Diff            @@
##           develop    #3549   +/-   ##
========================================
  Coverage    36.33%   36.33%           
========================================
  Files          419      419           
  Lines        59226    59221    -5     
========================================
- Hits         21520    21519    -1     
+ Misses       37706    37702    -4     
| Impacted Files | Coverage Δ |
|---|---|
| paddlenlp/trainer/trainer_compress.py | 9.01% <7.14%> (+0.30%) ⬆️ |
| paddlenlp/transformers/roberta/modeling.py | 89.85% <0.00%> (-0.37%) ⬇️ |
| paddlenlp/trainer/trainer.py | 11.24% <0.00%> (-0.10%) ⬇️ |


-        if "ptq" in args.strategy:
-            self.args.input_filename_prefix = "pruned_model"
+        if "ptq" in args.strategy or "qat" in args.strategy:
+            self.args.input_filename_prefix = "model"
Collaborator commented:

Why is the prefix changed here? It looks like the other model zoo models would need this prefix changed as well.

Author (Contributor) replied:

The export_model API saves the model with the prefix `model` and does not support a custom prefix name.

        if len(dynabert_shape) == 2:
            dynabert_state_dict[dynabert_key] = ori_state_dict[key][: dynabert_shape[0], : dynabert_shape[1]]
        else:
            dynabert_state_dict[dynabert_key] = ori_state_dict[key][: dynabert_shape[0]]
Collaborator commented:

I suggest raising an error here when the shape is not 1-D or 2-D, so that users do not run into problems when applying this to other models.

Author (Contributor) replied:

Thanks for the reminder. Exception-handling logic has been added.
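
The merged exception handling is not shown in this thread, so the following is only a sketch of the idea under stated assumptions: crop a full-width parameter to the sub-network's shape and reject any rank other than 1 or 2. The helper name `crop_to_shape` and the choice of `ValueError` are assumptions, and nested Python lists stand in for tensors.

```python
def crop_to_shape(value, target_shape):
    """Crop a nested-list 'tensor' to target_shape; only rank 1 or 2 is handled."""
    rank = len(target_shape)
    if rank == 1:
        return value[: target_shape[0]]
    if rank == 2:
        # Keep the first target_shape[0] rows and the first target_shape[1] columns.
        return [row[: target_shape[1]] for row in value[: target_shape[0]]]
    # Assumption: the error type used in the actual PR may differ.
    raise ValueError(f"Unsupported parameter rank {rank}; expected 1 or 2.")


full_weight = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
full_bias = [1, 2, 3]
print(crop_to_shape(full_weight, (2, 2)))  # [[1, 2], [4, 5]]
print(crop_to_shape(full_bias, (2,)))      # [1, 2]
```

Failing loudly on an unexpected rank avoids the `else` branch silently applying a 1-D slice to, say, a 4-D convolution weight from a different model.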

@LiuChiachi requested a review from wawltor December 28, 2022 12:00

wawltor previously approved these changes Dec 29, 2022

@wawltor (Collaborator) left a comment:

LGTM

@wawltor (Collaborator) left a comment:

LGTM

@LiuChiachi LiuChiachi merged commit ff3c0c9 into PaddlePaddle:develop Dec 29, 2022
2 participants