Skip to content

Commit

Permalink
[llm] update dpo criterion (#9620)
Browse files: browse the repository at this point in the history
* update dpo criterion

* update dpo criterion
  • Loading branch information
lugimzzz authored Dec 16, 2024
1 parent 5e1f01f commit f3ba5b3
Showing 1 changed file with 60 additions and 50 deletions.
110 changes: 60 additions & 50 deletions paddlenlp/trl/dpo_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy
from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp

from paddlenlp.transformers import (
AllGatherVarlenOp,
Expand All @@ -28,7 +29,6 @@
)
from paddlenlp.transformers.model_outputs import CausalLMOutputWithPast
from paddlenlp.utils import infohub
from paddlenlp.utils.tools import get_env_device


class DPOCriterion(nn.Layer):
Expand Down Expand Up @@ -148,16 +148,25 @@ def dpo_logps(
if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
labels, sparse_tgt_idx = sequence_parallel_sparse_mask_labels(labels, 0)

hidden_states = paddle.take_along_axis(hidden_states, sparse_tgt_idx, axis=0)
hidden_states = paddle.gather(hidden_states, sparse_tgt_idx, axis=0)
hidden_states = AllGatherVarlenOp.apply(hidden_states)
else:
labels = labels.flatten()
sparse_tgt_idx = paddle.nonzero(labels != 0).flatten()
labels = paddle.take_along_axis(labels, sparse_tgt_idx, axis=0)

hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]])
hidden_states = paddle.take_along_axis(hidden_states, sparse_tgt_idx.unsqueeze(-1), axis=0)

hidden_states = paddle.gather(hidden_states, sparse_tgt_idx, axis=0)
elif use_fused_head_and_loss_fn:
if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
hidden_states = GatherOp.apply(hidden_states)
hidden_states = hidden_states.reshape(
[
-1,
self.config.max_sequence_length,
hidden_states.shape[-1],
]
)
if use_fused_head_and_loss_fn:
per_token_logps = -fused_head_and_loss_fn(
hidden_states,
Expand Down Expand Up @@ -194,64 +203,65 @@ def dpo_logps(

if len(response_indexs.shape) == 3:
response_indexs = response_indexs[0]

offset = 1 if self.ignore_eos_token else 0
if use_sparse_head_and_loss_fn:
chosen_logps = paddle.stack(
[(per_token_logps[response_index[1] : response_index[2]]).sum() for response_index in response_indexs],
[
(
paddle.gather(
per_token_logps.reshape([-1]),
paddle.arange(response_index[1], response_index[2], dtype=paddle.int32),
axis=0,
).sum()
)
for response_index in response_indexs
],
axis=0,
)
rejected_logps = paddle.stack(
[(per_token_logps[response_index[2] : response_index[3]]).sum() for response_index in response_indexs],
[
(
paddle.gather(
per_token_logps.reshape([-1]),
paddle.arange(response_index[2] + offset, response_index[3], dtype=paddle.int32),
axis=0,
).sum()
)
for response_index in response_indexs
],
axis=0,
)
else:
if get_env_device() == "npu":
chosen_list = []
for response_index in response_indexs:
begin = response_index[1]
end = response_index[2]
one_data = paddle.ones_like(per_token_logps[0])
mask_data = paddle.zeros_like(per_token_logps[0])
paddle.assign(one_data._slice(begin, end), mask_data._slice(begin, end))
chosen_list.append((per_token_logps[0] * mask_data).sum())
chosen_logps = paddle.stack(chosen_list, axis=0)
rejected_list = []
for response_index in response_indexs:
begin = response_index[2]
if self.ignore_eos_token:
begin += 1
end = response_index[3]
one_data = paddle.ones_like(per_token_logps[0])
mask_data = paddle.zeros_like(per_token_logps[0])
paddle.assign(one_data._slice(begin, end), mask_data._slice(begin, end))
rejected_list.append((per_token_logps[0] * mask_data).sum())
rejected_logps = paddle.stack(rejected_list, axis=0)
else:
chosen_logps = paddle.stack(
[
(per_token_logps[response_index[0]][response_index[1] : response_index[2]]).sum()
for response_index in response_indexs
],
axis=0,
)
if self.ignore_eos_token:
rejected_logps = paddle.stack(
[
(per_token_logps[response_index[0]][response_index[2] + 1 : response_index[3]]).sum()
for response_index in response_indexs
],
axis=0,
chosen_logps = paddle.stack(
[
(
paddle.gather(
paddle.gather(per_token_logps, response_index[0], axis=0),
paddle.arange(response_index[1], response_index[2], dtype=paddle.int32),
axis=0,
).sum()
)
else:
rejected_logps = paddle.stack(
[
(per_token_logps[response_index[0]][response_index[2] : response_index[3]]).sum()
for response_index in response_indexs
],
axis=0,
for response_index in response_indexs
],
axis=0,
)
rejected_logps = paddle.stack(
[
(
paddle.gather(
paddle.gather(per_token_logps, response_index[0], axis=0),
paddle.arange(response_index[2] + offset, response_index[3], dtype=paddle.int32),
axis=0,
).sum()
)
for response_index in response_indexs
],
axis=0,
)
sft_loss = -chosen_logps.sum() / (chosen_labels != 0).sum()
if average_log_prob:
chosen_response_length = response_indexs[:, 2] - response_indexs[:, 1]
chosen_response_length = response_indexs[:, 2] - response_indexs[:, 1] - offset
rejected_response_length = response_indexs[:, 3] - response_indexs[:, 2]
chosen_logps /= chosen_response_length.astype("float32")
rejected_logps /= rejected_response_length.astype("float32")
Expand Down

0 comments on commit f3ba5b3

Please sign in to comment.