-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Embedding] Add embedding training #9508
Merged
Merged
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
c141720
add Qwen2SentenceEmbedding
DrownFish19 6e9efb2
update modeling
DrownFish19 89d23e6
Merge remote-tracking branch 'paddlenlp/develop' into dev_20241121_ad…
DrownFish19 4d974e5
add embedding trainer
DesmonDay f408f2e
embedding
DrownFish19 a3da81b
fix
DrownFish19 18405ed
Merge remote-tracking branch 'paddlenlp-daisiming/add_embedding_train…
DrownFish19 f8e877b
Merge remote-tracking branch 'paddlenlp/develop' into dev_20241121_ad…
DrownFish19 e6394ad
update
DrownFish19 ba2c286
support cross device
DesmonDay a71783b
Merge branch 'dev_20241121_add_qwen2_embedding' of https://github.com…
DesmonDay 759d832
update trainer
DesmonDay b92df93
add loss
DesmonDay d3d5a7f
delete unused code
DesmonDay 88ba45e
delete unused code
DesmonDay b5c08aa
optimize code
DesmonDay d815fce
update
DesmonDay 0a618b0
update
DesmonDay 344d4f0
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
DesmonDay 39d324d
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
DesmonDay File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import paddle | ||
|
||
|
||
class SimpleContrastiveLoss(paddle.nn.Layer): | ||
def __init__(self, embedding_temperature: float = 0.02): | ||
super().__init__() | ||
self.embedding_temperature = embedding_temperature | ||
self.cross_entropy = paddle.nn.CrossEntropyLoss(reduction="mean") | ||
|
||
def forward(self, q_reps, p_reps): | ||
scores = paddle.matmul(q_reps, p_reps.transpose([1, 0])) | ||
scores = scores / self.embedding_temperature | ||
|
||
group_size = p_reps.shape[0] // q_reps.shape[0] | ||
batch_size = q_reps.shape[0] | ||
|
||
target = paddle.arange(batch_size, dtype="int64") | ||
target = target * group_size | ||
|
||
loss = self.cross_entropy(scores, target) | ||
return loss | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from contextlib import nullcontext | ||
|
||
import paddle | ||
from paddle.base import core | ||
from paddle.distributed import fleet | ||
|
||
from paddlenlp.trainer import Trainer | ||
from paddlenlp.transformers.contrastive_loss import SimpleContrastiveLoss | ||
|
||
__all__ = ["EmbeddingTrainer"] | ||
|
||
|
||
class EmbeddingTrainer(Trainer): | ||
def __init__(self, model_args, use_gradient_cache=False, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
self.model_args = model_args | ||
self.use_gradient_cache = use_gradient_cache | ||
self.accum_data = [] | ||
self.accum_freq = 0 | ||
self.accum_q_features = [] | ||
self.accum_p_features = [] | ||
self.accum_rng_states = {} | ||
self.accum_rng_states["cpu"] = [] | ||
self.accum_rng_states["cuda"] = [] | ||
self.accum_rng_states["hybrid"] = [] | ||
self.loss_fn = SimpleContrastiveLoss(self.model_args.embedding_temperature) | ||
|
||
def clear_memory(self): | ||
self.accum_q_features.clear() | ||
self.accum_p_features.clear() | ||
paddle.device.cuda.empty_cache() | ||
|
||
def clear_state(self): | ||
self.accum_data.clear() | ||
self.accum_rng_states["cpu"].clear() | ||
self.accum_rng_states["cuda"].clear() | ||
self.accum_rng_states["hybrid"].clear() | ||
self.accum_freq = 0 | ||
|
||
@paddle.no_grad() | ||
def forward_no_grad(self, model, inputs): | ||
# Step1: graph-less forward | ||
self.accum_data.append(inputs) | ||
inputs = self._prepare_inputs(inputs) | ||
with self.autocast_smart_context_manager(): | ||
# collect rand states | ||
self.accum_rng_states["cpu"].append(paddle.framework.core.default_cpu_generator().get_state()) | ||
self.accum_rng_states["cuda"].append(paddle.get_rng_state()) | ||
if self.args.use_hybrid_parallel: | ||
self.accum_rng_states["hybrid"].append( | ||
fleet.meta_parallel.get_rng_state_tracker().get_states_tracker() | ||
) | ||
|
||
query_reps, passage_reps = model(**inputs, return_encode=True) | ||
|
||
self.accum_q_features.append(query_reps) | ||
self.accum_p_features.append(passage_reps) | ||
|
||
self.accum_freq += 1 | ||
|
||
def get_current_rng_state(self): | ||
return { | ||
"cpu": [paddle.framework.core.default_cpu_generator().get_state()], | ||
"cuda": [paddle.get_rng_state()], | ||
"hybrid": [fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()], | ||
} | ||
|
||
def reset_rng_state(self, states, index=0): | ||
# set random states | ||
if len(states) != 3: | ||
raise ValueError("The length of state should be 3") | ||
cpu_state = states["cpu"][index] | ||
cuda_state = states["cuda"][index] | ||
hybrid_state = states["hybrid"][index] | ||
paddle.framework.core.default_cpu_generator().set_state(cpu_state) | ||
# TODO(daisiming): support xpu and other custom devices. | ||
if core.is_compiled_with_cuda(): | ||
for j in range(core.get_cuda_device_count()): | ||
core.default_cuda_generator(j).set_state(cuda_state[j]) | ||
if self.args.use_hybrid_parallel: | ||
fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(hybrid_state) | ||
|
||
def accum_forward_backward(self, model): | ||
# Step2: representation gradient computation and caching | ||
for i in range(len(self.accum_q_features)): | ||
self.accum_q_features[i].stop_gradient = False | ||
q_reps = paddle.concat(self.accum_q_features, axis=0) | ||
for i in range(len(self.accum_p_features)): | ||
self.accum_p_features[i].stop_gradient = False | ||
p_reps = paddle.concat(self.accum_p_features, axis=0) | ||
|
||
loss = self.loss_fn(q_reps, p_reps) | ||
if self.do_grad_scaling: | ||
self.scaler.scale(loss).backward() | ||
else: | ||
loss.backward() | ||
# get represetation gradient cache | ||
accum_q_grads = [q.grad for q in self.accum_q_features] | ||
accum_p_grads = [p.grad for p in self.accum_p_features] | ||
del q_reps, p_reps | ||
|
||
# clear trash memory | ||
self.clear_memory() | ||
|
||
current_rng_state = self.get_current_rng_state() | ||
# Step3: sub-batch gradient accumulation | ||
for i in range(self.accum_freq): | ||
inputs = self.accum_data[i] | ||
inputs = self._prepare_inputs(inputs) | ||
|
||
sync_context = model.no_sync() if i != self.accum_freq - 1 and hasattr(model, "no_sync") else nullcontext() | ||
with sync_context: | ||
self.reset_rng_state(self.accum_rng_states, index=i) | ||
|
||
with self.autocast_smart_context_manager(): | ||
query_reps, passage_reps = model(**inputs, return_encode=True) | ||
|
||
_loss = paddle.dot(query_reps.flatten(), accum_q_grads[i].flatten()) + paddle.dot( | ||
passage_reps.flatten(), accum_p_grads[i].flatten() | ||
) | ||
_loss.backward() | ||
|
||
self.reset_rng_state(current_rng_state) | ||
self.clear_state() | ||
return loss.detach() | ||
|
||
def training_step( | ||
self, | ||
model, | ||
inputs, | ||
step_control=0, | ||
): | ||
if self.args.pipeline_parallel_degree > 1: | ||
raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.") | ||
|
||
if self.args.gradient_accumulation_steps == 1 or not self.use_gradient_cache: | ||
return super().training_step(model, inputs) | ||
else: | ||
self.forward_no_grad(model, inputs) | ||
|
||
# if (step_control + 1) % self.args.gradient_accumulation_steps is not zero, move on to next batch. | ||
if (step_control + 1) % self.args.gradient_accumulation_steps != 0: | ||
return 0.0 | ||
|
||
loss = self.accum_forward_backward(model) | ||
return loss | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块得看看有无更好的方法
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里要兼容,判断 self.training_step 有没有 step_control 参数