Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Embedding] Add embedding training #9508

Merged
merged 20 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,9 +1093,9 @@
if is_no_sync:
# Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
with model.no_sync():
tr_loss_step = self.training_step(model, inputs)
tr_loss_step = self.training_step(model, inputs, step_control=step_control)

Check warning on line 1096 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1096

Added line #L1096 was not covered by tests
Copy link
Contributor Author

@DesmonDay DesmonDay Nov 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块得看看有无更好的方法

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里要兼容,判断 self.training_step 有没有 step_control 参数

else:
tr_loss_step = self.training_step(model, inputs)
tr_loss_step = self.training_step(model, inputs, step_control=step_control)

tr_loss += tr_loss_step

Expand Down Expand Up @@ -2267,7 +2267,9 @@
else:
return False

def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
def training_step(
self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]], step_control=0
) -> paddle.Tensor:
"""
Perform a training step on a batch of inputs.

Expand Down
36 changes: 36 additions & 0 deletions paddlenlp/transformers/contrastive_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import paddle


class SimpleContrastiveLoss(paddle.nn.Layer):
def __init__(self, embedding_temperature: float = 0.02):
super().__init__()
self.embedding_temperature = embedding_temperature
self.cross_entropy = paddle.nn.CrossEntropyLoss(reduction="mean")

Check warning on line 23 in paddlenlp/transformers/contrastive_loss.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/contrastive_loss.py#L21-L23

Added lines #L21 - L23 were not covered by tests

def forward(self, q_reps, p_reps):
scores = paddle.matmul(q_reps, p_reps.transpose([1, 0]))
scores = scores / self.embedding_temperature

Check warning on line 27 in paddlenlp/transformers/contrastive_loss.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/contrastive_loss.py#L26-L27

Added lines #L26 - L27 were not covered by tests

group_size = p_reps.shape[0] // q_reps.shape[0]
batch_size = q_reps.shape[0]

Check warning on line 30 in paddlenlp/transformers/contrastive_loss.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/contrastive_loss.py#L29-L30

Added lines #L29 - L30 were not covered by tests

target = paddle.arange(batch_size, dtype="int64")
target = target * group_size

Check warning on line 33 in paddlenlp/transformers/contrastive_loss.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/contrastive_loss.py#L32-L33

Added lines #L32 - L33 were not covered by tests

loss = self.cross_entropy(scores, target)
return loss

Check warning on line 36 in paddlenlp/transformers/contrastive_loss.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/contrastive_loss.py#L35-L36

Added lines #L35 - L36 were not covered by tests
1 change: 1 addition & 0 deletions paddlenlp/trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .dpo_criterion import DPOCriterion
from .dpo_trainer import DPOTrainer
from .embedding_trainer import EmbeddingTrainer
from .kto_criterion import KTOCriterion
from .kto_trainer import KTOTrainer
from .sft_trainer import *
Expand Down
161 changes: 161 additions & 0 deletions paddlenlp/trl/embedding_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext

import paddle
from paddle.base import core
from paddle.distributed import fleet

from paddlenlp.trainer import Trainer
from paddlenlp.transformers.contrastive_loss import SimpleContrastiveLoss

__all__ = ["EmbeddingTrainer"]


class EmbeddingTrainer(Trainer):
def __init__(self, model_args, use_gradient_cache=False, **kwargs):
super().__init__(**kwargs)

Check warning on line 29 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L29

Added line #L29 was not covered by tests

self.model_args = model_args
self.use_gradient_cache = use_gradient_cache
self.accum_data = []
self.accum_freq = 0
self.accum_q_features = []
self.accum_p_features = []
self.accum_rng_states = {}
self.accum_rng_states["cpu"] = []
self.accum_rng_states["cuda"] = []
self.accum_rng_states["hybrid"] = []
self.loss_fn = SimpleContrastiveLoss(self.model_args.embedding_temperature)

Check warning on line 41 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L31-L41

Added lines #L31 - L41 were not covered by tests

def clear_memory(self):
self.accum_q_features.clear()
self.accum_p_features.clear()
paddle.device.cuda.empty_cache()

Check warning on line 46 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L44-L46

Added lines #L44 - L46 were not covered by tests

def clear_state(self):
self.accum_data.clear()
self.accum_rng_states["cpu"].clear()
self.accum_rng_states["cuda"].clear()
self.accum_rng_states["hybrid"].clear()
self.accum_freq = 0

Check warning on line 53 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L49-L53

Added lines #L49 - L53 were not covered by tests

@paddle.no_grad()
def forward_no_grad(self, model, inputs):
# Step1: graph-less forward
self.accum_data.append(inputs)
inputs = self._prepare_inputs(inputs)
with self.autocast_smart_context_manager():

Check warning on line 60 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L58-L60

Added lines #L58 - L60 were not covered by tests
# collect rand states
self.accum_rng_states["cpu"].append(paddle.framework.core.default_cpu_generator().get_state())
self.accum_rng_states["cuda"].append(paddle.get_rng_state())
if self.args.use_hybrid_parallel:
self.accum_rng_states["hybrid"].append(

Check warning on line 65 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L62-L65

Added lines #L62 - L65 were not covered by tests
fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()
)

query_reps, passage_reps = model(**inputs, return_encode=True)

Check warning on line 69 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L69

Added line #L69 was not covered by tests

self.accum_q_features.append(query_reps)
self.accum_p_features.append(passage_reps)

Check warning on line 72 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L71-L72

Added lines #L71 - L72 were not covered by tests

self.accum_freq += 1

Check warning on line 74 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L74

Added line #L74 was not covered by tests

def get_current_rng_state(self):
return {

Check warning on line 77 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L77

Added line #L77 was not covered by tests
"cpu": [paddle.framework.core.default_cpu_generator().get_state()],
"cuda": [paddle.get_rng_state()],
"hybrid": [fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()],
}

def reset_rng_state(self, states, index=0):
# set random states
if len(states) != 3:
raise ValueError("The length of state should be 3")
cpu_state = states["cpu"][index]
cuda_state = states["cuda"][index]
hybrid_state = states["hybrid"][index]
paddle.framework.core.default_cpu_generator().set_state(cpu_state)

Check warning on line 90 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L85-L90

Added lines #L85 - L90 were not covered by tests
# TODO(daisiming): support xpu and other custom devices.
if core.is_compiled_with_cuda():
for j in range(core.get_cuda_device_count()):
core.default_cuda_generator(j).set_state(cuda_state[j])
if self.args.use_hybrid_parallel:
fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(hybrid_state)

Check warning on line 96 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L92-L96

Added lines #L92 - L96 were not covered by tests

def accum_forward_backward(self, model):
# Step2: representation gradient computation and caching
for i in range(len(self.accum_q_features)):
self.accum_q_features[i].stop_gradient = False
q_reps = paddle.concat(self.accum_q_features, axis=0)
for i in range(len(self.accum_p_features)):
self.accum_p_features[i].stop_gradient = False
p_reps = paddle.concat(self.accum_p_features, axis=0)

Check warning on line 105 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L100-L105

Added lines #L100 - L105 were not covered by tests

loss = self.loss_fn(q_reps, p_reps)
if self.do_grad_scaling:
self.scaler.scale(loss).backward()

Check warning on line 109 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L107-L109

Added lines #L107 - L109 were not covered by tests
else:
loss.backward()

Check warning on line 111 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L111

Added line #L111 was not covered by tests
# get represetation gradient cache
accum_q_grads = [q.grad for q in self.accum_q_features]
accum_p_grads = [p.grad for p in self.accum_p_features]
del q_reps, p_reps

Check warning on line 115 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L113-L115

Added lines #L113 - L115 were not covered by tests

# clear trash memory
self.clear_memory()

Check warning on line 118 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L118

Added line #L118 was not covered by tests

current_rng_state = self.get_current_rng_state()

Check warning on line 120 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L120

Added line #L120 was not covered by tests
# Step3: sub-batch gradient accumulation
for i in range(self.accum_freq):
inputs = self.accum_data[i]
inputs = self._prepare_inputs(inputs)

Check warning on line 124 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L122-L124

Added lines #L122 - L124 were not covered by tests

sync_context = model.no_sync() if i != self.accum_freq - 1 and hasattr(model, "no_sync") else nullcontext()
with sync_context:
self.reset_rng_state(self.accum_rng_states, index=i)

Check warning on line 128 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L126-L128

Added lines #L126 - L128 were not covered by tests

with self.autocast_smart_context_manager():
query_reps, passage_reps = model(**inputs, return_encode=True)

Check warning on line 131 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L130-L131

Added lines #L130 - L131 were not covered by tests

_loss = paddle.dot(query_reps.flatten(), accum_q_grads[i].flatten()) + paddle.dot(

Check warning on line 133 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L133

Added line #L133 was not covered by tests
passage_reps.flatten(), accum_p_grads[i].flatten()
)
_loss.backward()

Check warning on line 136 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L136

Added line #L136 was not covered by tests

self.reset_rng_state(current_rng_state)
self.clear_state()
return loss.detach()

Check warning on line 140 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L138-L140

Added lines #L138 - L140 were not covered by tests

def training_step(
self,
model,
inputs,
step_control=0,
):
if self.args.pipeline_parallel_degree > 1:
raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.")

Check warning on line 149 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L148-L149

Added lines #L148 - L149 were not covered by tests

if self.args.gradient_accumulation_steps == 1 or not self.use_gradient_cache:
return super().training_step(model, inputs)

Check warning on line 152 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L151-L152

Added lines #L151 - L152 were not covered by tests
else:
self.forward_no_grad(model, inputs)

Check warning on line 154 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L154

Added line #L154 was not covered by tests

# if (step_control + 1) % self.args.gradient_accumulation_steps is not zero, move on to next batch.
if (step_control + 1) % self.args.gradient_accumulation_steps != 0:
return 0.0

Check warning on line 158 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L157-L158

Added lines #L157 - L158 were not covered by tests

loss = self.accum_forward_backward(model)
return loss

Check warning on line 161 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L160-L161

Added lines #L160 - L161 were not covered by tests