Merge pull request #4 from gongel/ppo-4d
[PPO] Format code
guoshengCS authored Jun 12, 2024
2 parents 1feb5ee + c8b3c61 commit 1b8e4a3
Showing 5 changed files with 32 additions and 36 deletions.
1 change: 1 addition & 0 deletions examples/RLHF/infer_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from __future__ import annotations
+
 import copy
 import inspect
 import types
9 changes: 3 additions & 6 deletions examples/RLHF/ppo_main.py
@@ -37,8 +37,8 @@
 # os.environ["https_proxy"] = "agent.baidu.com:8118"
 
 from dataclasses import dataclass, field
-from typing import Any, Dict, Tuple
 from functools import partial
+from typing import Any, Dict, Tuple
 
 import paddle
 from data import PromptOnlyDataset, SupervisedDataset, parse_dataset
@@ -151,9 +151,7 @@ class TrainingArguments(TrainingArguments):
     )
     use_fusemt: bool = field(
         default=True,
-        metadata={
-            "help": "use inference model to speedup in rollout generation"
-        },
+        metadata={"help": "use inference model to speedup in rollout generation"},
     )
 
     # save_generation_output: bool = field(
@@ -485,8 +483,7 @@ def main():
     )
     if ptx_ds is not None:
         # PretrainingCriterion requires shifted inputs and labels
-        ptx_ds.get_collator = types.MethodType(
-            partial(ptx_ds.get_collator.__func__, shift=True), ptx_ds)
+        ptx_ds.get_collator = types.MethodType(partial(ptx_ds.get_collator.__func__, shift=True), ptx_ds)
 
     # offload
     # cleanup actor_eval_model, reward_critic_eval_model
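Note on the one-liner above: `types.MethodType(partial(method.__func__, shift=True), obj)` takes the function behind a bound method, pins an extra keyword argument with `functools.partial`, and rebinds the result to the same instance, so later calls pick up the new default without touching the class or other instances. A minimal, self-contained sketch of the pattern (the `ToyDataset`/`get_collator` names below are illustrative, not the repo's classes):

```python
import types
from functools import partial


class ToyDataset:
    def get_collator(self, shift=False):
        return f"collator(shift={shift})"


ds = ToyDataset()
# Rebind get_collator on this instance only, with shift defaulting to True.
ds.get_collator = types.MethodType(partial(ds.get_collator.__func__, shift=True), ds)
print(ds.get_collator())  # collator(shift=True)
```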
56 changes: 27 additions & 29 deletions examples/RLHF/ppo_trainer.py
@@ -1005,11 +1005,10 @@ def init_train_model_opt(
         return policy_model, value_model
 
     def get_epoch_iterator(self):
-
         def gen_epoch_data():
             for prompt_only_batch, ptx_batch in zip(
-                    self.prompt_only_dataloader,
-                    itertools.cycle(self.ptx_dataloader),
+                self.prompt_only_dataloader,
+                itertools.cycle(self.ptx_dataloader),
             ):
                 # generate batches
                 self.set_eval()
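The `zip(..., itertools.cycle(...))` pairing above lets the (typically shorter) PTX dataloader repeat for as long as the prompt dataloader runs, so the epoch length is governed by `prompt_only_dataloader`. A tiny illustration with plain lists standing in for the two dataloaders:

```python
import itertools

prompt_batches = ["p0", "p1", "p2", "p3", "p4"]  # stands in for prompt_only_dataloader
ptx_batches = ["x0", "x1"]                       # stands in for ptx_dataloader

# cycle() repeats the shorter PTX stream; zip() stops with the prompt stream.
for prompt_only_batch, ptx_batch in zip(prompt_batches, itertools.cycle(ptx_batches)):
    print(prompt_only_batch, ptx_batch)
# p0 x0 / p1 x1 / p2 x0 / p3 x1 / p4 x0
```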
@@ -1037,10 +1036,11 @@ def __iter__(self):

             def __len__(self):
                 return len(self.prompt_only_dataloader) * (
-                    self.args.update_iters *
-                    self.args.per_device_prompt_batch_size *
-                    self.args.num_return_sequences //
-                    self.args.per_device_train_batch_size)
+                    self.args.update_iters
+                    * self.args.per_device_prompt_batch_size
+                    * self.args.num_return_sequences
+                    // self.args.per_device_train_batch_size
+                )
 
         return EpochIterator()

@@ -1051,10 +1051,13 @@ def init_train_num(self: Trainer, train_dataloader: DataLoader):
         len_dataloader = None
         if not self._is_iterable_dataset(self.train_dataset):
             len_dataloader = len(train_dataloader)
-            num_train_sub_steps = (len_dataloader * self.args.update_iters *
-                                   self.args.per_device_prompt_batch_size *
-                                   self.args.num_return_sequences //
-                                   self.args.per_device_train_batch_size)
+            num_train_sub_steps = (
+                len_dataloader
+                * self.args.update_iters
+                * self.args.per_device_prompt_batch_size
+                * self.args.num_return_sequences
+                // self.args.per_device_train_batch_size
+            )
             num_update_steps_per_epoch = num_train_sub_steps // args.gradient_accumulation_steps
             num_examples = len(self.train_dataset)
             if args.max_steps > 0:
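`__len__` above and `num_train_sub_steps` here use the same accounting: each prompt batch yields `per_device_prompt_batch_size * num_return_sequences` rollout samples, the rollout buffer is revisited `update_iters` times, and the result is re-chunked into micro-batches of `per_device_train_batch_size`. A quick sanity check of the formula with made-up numbers (the values below are illustrative only):

```python
# Hypothetical settings, chosen only to illustrate the formula.
len_dataloader = 100               # prompt batches per epoch
update_iters = 1                   # passes over each rollout buffer
per_device_prompt_batch_size = 8   # prompts generated per batch
num_return_sequences = 2           # rollouts sampled per prompt
per_device_train_batch_size = 4    # samples per RL micro-batch

num_train_sub_steps = (
    len_dataloader
    * update_iters
    * per_device_prompt_batch_size
    * num_return_sequences
    // per_device_train_batch_size
)
print(num_train_sub_steps)  # 100 * 1 * 8 * 2 // 4 = 400
```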
@@ -1116,18 +1119,15 @@ def train(

         if self.use_ptx:
             with guard_set_args(
-                    args,
+                args,
                 {
-                    "per_device_train_batch_size":
-                    1 if getattr(self.ptx_dataset, "is_intokens", False) else
-                    self.args.per_device_prompt_batch_size *
-                    self.args.num_return_sequences
+                    "per_device_train_batch_size": 1
+                    if getattr(self.ptx_dataset, "is_intokens", False)
+                    else self.args.per_device_prompt_batch_size * self.args.num_return_sequences
                 },
             ), guard_set_args(
-                self, {
-                    "train_dataset": self.ptx_dataset,
-                    "data_collator": self.ptx_dataset.get_collator()
-                }):
+                self, {"train_dataset": self.ptx_dataset, "data_collator": self.ptx_dataset.get_collator()}
+            ):
                 self.ptx_dataloader = self.get_train_dataloader()
         else:
             self.ptx_dataloader = range(100)
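`guard_set_args` is used here as a context manager that temporarily overrides attributes on an object (the training `args`, or the trainer itself) and restores the previous values on exit. A rough sketch of such a helper is shown below; this is an assumption inferred from the call sites, not the repo's actual implementation:

```python
import contextlib


@contextlib.contextmanager
def guard_set_args(obj, overrides):
    # Remember the current attribute values, apply the overrides,
    # and restore the originals when the with-block exits.
    saved = {name: getattr(obj, name, None) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(obj, name, value)
        yield obj
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)
```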
@@ -1205,10 +1205,13 @@ def train(
                 if self.use_ptx:
                     logger.info("Doing ptx step...")
                     self.timers and self.timers("ptx_step").start()
-                    with guard_set_args(self._model_config, {
+                    with guard_set_args(
+                        self._model_config,
+                        {
                             # "set_attn_func": True,
                             # "use_flash_attention": True
-                    }):
+                        },
+                    ):
                         ptx_info = self.ptx_step(ptx_batch)
                         rl_info.update(ptx_info)
                     self.timers and self.timers("ptx_step").stop()
@@ -1299,13 +1302,8 @@ def add_kl_divergence_regularization(
             max=self.clip_range_score,
         )
         # TODO(guosheng): use scatter_add/put_along_axis
-        index = paddle.cumsum(sequence_mask.cast(paddle.int64),
-                              axis=-1).argmax(-1, keepdim=True)
-        rewards = paddle.put_along_axis(rewards,
-                                        index,
-                                        reward_clip.unsqueeze(axis=-1),
-                                        axis=-1,
-                                        reduce="add")
+        index = paddle.cumsum(sequence_mask.cast(paddle.int64), axis=-1).argmax(-1, keepdim=True)
+        rewards = paddle.put_along_axis(rewards, index, reward_clip.unsqueeze(axis=-1), axis=-1, reduce="add")
         # batch_size = log_probs.shape[0]
         # for i in range(batch_size):
         # # print("="*20, sequence_mask[i])
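For the reformatted lines in `add_kl_divergence_regularization`: the cumulative sum of `sequence_mask` stops growing after the last non-masked position, so `argmax` over it (which returns the first maximal index) yields the index of the final response token per sequence, and `put_along_axis(..., reduce="add")` adds the clipped sequence-level reward at exactly that position. A small standalone sketch with assumed `[batch, seq_len]` shapes (the tensor values below are made up for illustration):

```python
import paddle

rewards = paddle.zeros([2, 5])  # per-token rewards, e.g. the KL penalty term
sequence_mask = paddle.to_tensor([[1, 1, 1, 0, 0],
                                  [1, 1, 1, 1, 0]], dtype="bool")
reward_clip = paddle.to_tensor([0.7, -0.3])  # clipped sequence-level reward

# cumsum of the mask is maximal from the last True entry onward, so argmax
# points at the last valid (non-padding) token of each sequence.
index = paddle.cumsum(sequence_mask.cast(paddle.int64), axis=-1).argmax(-1, keepdim=True)
rewards = paddle.put_along_axis(rewards, index, reward_clip.unsqueeze(axis=-1), axis=-1, reduce="add")
print(index.numpy().ravel())  # [2 3] -> last non-padding positions
```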
1 change: 1 addition & 0 deletions examples/RLHF/trainer_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from __future__ import annotations
+
 import inspect
 import os
 import time
1 change: 0 additions & 1 deletion paddlenlp/trainer/utils/helper.py
@@ -29,7 +29,6 @@
     nested_broadcast_tensor,
     nested_empty_tensor,
     nested_reduce_tensor,
-    nested_broadcast_tensor_with_empty,
 )
 
 __all__ = [
