[Embedding] Add embedding training #9508
Thanks for your contribution!
paddlenlp/trainer/trainer.py
Outdated
@@ -1093,9 +1093,9 @@ def _inner_training_loop(
 if is_no_sync:
     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
     with model.no_sync():
-        tr_loss_step = self.training_step(model, inputs)
+        tr_loss_step = self.training_step(model, inputs, step_control=step_control)
We should check whether there is a better way to do this.
This needs to stay backward compatible: check whether self.training_step accepts a step_control parameter.
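One possible way to do the compatibility check requested above is to inspect the signature of the bound method before calling it. This is only a sketch: `call_training_step`, `LegacyTrainer`, and `EmbeddingTrainer` are hypothetical names used for illustration and are not part of this PR or of PaddleNLP.

```python
import inspect


def call_training_step(training_step, model, inputs, step_control=None):
    """Call training_step, passing step_control only when it is accepted.

    Hypothetical helper for illustration; not part of the PR.
    """
    params = inspect.signature(training_step).parameters
    # A **kwargs parameter would also accept step_control.
    accepts_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if "step_control" in params or accepts_kwargs:
        return training_step(model, inputs, step_control=step_control)
    # Older overrides keep working with the original two-argument call.
    return training_step(model, inputs)


class LegacyTrainer:
    # Stand-in for a subclass written against the old signature.
    def training_step(self, model, inputs):
        return "legacy"


class EmbeddingTrainer:
    # Stand-in for a subclass that understands step_control.
    def training_step(self, model, inputs, step_control=0):
        return f"embedding:{step_control}"


legacy_result = call_training_step(LegacyTrainer().training_step, None, {}, step_control=1)
embedding_result = call_training_step(EmbeddingTrainer().training_step, None, {}, step_control=1)
```

With this shape, existing subclasses that override `training_step(model, inputs)` are called as before, while overrides that declare `step_control` receive it.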
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##           develop    #9508      +/-   ##
===========================================
- Coverage    52.77%   52.71%   -0.07%
===========================================
  Files          709      710       +1
  Lines       111172   111417     +245
===========================================
+ Hits         58674    58733      +59
- Misses       52498    52684     +186
…er' into dev_20241121_add_qwen2_embedding
…d_qwen2_embedding
20f56e6 to ba2c286
7964ad3 to d815fce
llm/config/qwen/emb_argument.json
Outdated
"max_query_len": 1024,
"max_passage_len": 2048,
"group_size": 4,
"bp16": true,
bf16
model_config.embedding_negatives_cross_device = embedding_args.embedding_negatives_cross_device
logger.info(f"Final model config: {model_config}")

model_class = Qwen2SentenceEmbedding
Let's change this later; we could add an Auto class.
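As a rough illustration of the "Auto class" idea, a small registry could map a model type string to its embedding class so the training script no longer hard-codes `Qwen2SentenceEmbedding`. Everything below is an assumption made for the sketch: `AutoSentenceEmbedding`, `register_embedding_model`, the `"model_type"` key, and the stand-in class body are hypothetical and do not reflect PaddleNLP's actual Auto classes.

```python
# Hypothetical registry: model type string -> embedding model class.
_EMBEDDING_CLASSES = {}


def register_embedding_model(model_type):
    """Class decorator that records an embedding class under model_type."""
    def decorator(cls):
        _EMBEDDING_CLASSES[model_type] = cls
        return cls
    return decorator


@register_embedding_model("qwen2")
class Qwen2SentenceEmbedding:
    # Stand-in for the real class from the PR; it only stores the config.
    def __init__(self, config):
        self.config = config


class AutoSentenceEmbedding:
    """Hypothetical dispatcher that picks the class from the config."""

    @classmethod
    def from_config(cls, config):
        model_type = config["model_type"]
        try:
            model_class = _EMBEDDING_CLASSES[model_type]
        except KeyError:
            supported = ", ".join(sorted(_EMBEDDING_CLASSES))
            raise ValueError(
                f"Unsupported model type {model_type!r}; supported: {supported}"
            ) from None
        return model_class(config)


model = AutoSentenceEmbedding.from_config({"model_type": "qwen2"})
```

Supporting another backbone would then only require registering its class, without touching the training script.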
trainable_parameters = [p for p in model.parameters() if not p.stop_gradient]
trainer.set_optimizer_grouped_parameters(trainable_parameters)
@lugimzzz, why was this added originally?
1328048 to 63ba9d0
63ba9d0 to 0a618b0
…nto add_embedding_trainer
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
Please add a README covering usage and the dataset format.
…nto add_embedding_trainer
Please add the README and sample data first. The warm-start resume issue can be analyzed later.
PR types
New features
PR changes
Others
Description
Support embedding training.