From bf90fdb0bb4b4ee33ea1f93a9ddceba3d2c73945 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 24 Sep 2024 17:39:23 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 nemo_aligner/algorithms/rpo.py                   | 29 +++++-----
 nemo_aligner/data/nlp/builders.py                |  2 +-
 nemo_aligner/data/nlp/datasets.py                | 35 ++++++------
 .../models/nlp/gpt/megatron_gpt_rpo_model.py     | 56 ++++++++++++-------
 4 files changed, 69 insertions(+), 53 deletions(-)

diff --git a/nemo_aligner/algorithms/rpo.py b/nemo_aligner/algorithms/rpo.py
index a7de3d2ac..991f0b8e2 100644
--- a/nemo_aligner/algorithms/rpo.py
+++ b/nemo_aligner/algorithms/rpo.py
@@ -32,28 +32,28 @@ def rpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):
     resp_outputs = {}
-
+
     ## assume len 4 for responses
     for k in batch[0].keys():
-        if k.startswith('response_'):
+        if k.startswith("response_"):
             # get response_i
             resp_outputs[k] = torch.nn.utils.rnn.pad_sequence(
-                [item[k] for item in batch],
-                batch_first=True, padding_value=eos_id)
-        elif k.startswith('labels_'):
+                [item[k] for item in batch], batch_first=True, padding_value=eos_id
+            )
+        elif k.startswith("labels_"):
             # get labels_i
             resp_outputs[k] = torch.nn.utils.rnn.pad_sequence(
-                [item[k] for item in batch],
-                batch_first=True, padding_value=-100)
-        elif k.startswith('lengths_'):
+                [item[k] for item in batch], batch_first=True, padding_value=-100
+            )
+        elif k.startswith("lengths_"):
             # get lens_i
             resp_outputs[k] = torch.LongTensor([item[k] for item in batch])
-        elif k.startswith('rewards_'):
+        elif k.startswith("rewards_"):
             # get r_i
             resp_outputs[k] = torch.FloatTensor([item[k] for item in batch])
-
+
     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
-        resp_outputs['response_1'], eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
+        resp_outputs["response_1"], eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
     )
     assert attention_mask.ndim == 4, "attention_mask is incorrect shape for dpo_custom_collate"
     if attention_mask.shape[0] == 1:
@@ -61,10 +61,9 @@ def rpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_
         # attention_mask = attention_mask.expand(len(batch), *((-1,) * (len(attention_mask.shape) - 1)))
         attention_mask = attention_mask.repeat(4, *((1,) * (len(attention_mask.shape) - 1)))
-

     resp_outputs["attention_mask"] = attention_mask
     resp_outputs["position_ids"] = position_ids
-
+
     return resp_outputs
@@ -319,8 +318,8 @@ def augment_dataloader(self, dataloader):
                 batch = next(iter_dataloader)
                 logprobs = self.model.get_ref_policy_logprobs(batch).cpu()
                 ind = 1
-
-                for logps in torch.split(logprobs, len(logprobs) // self.k_len, dim=0):
+
+                for logps in torch.split(logprobs, len(logprobs) // self.k_len, dim=0):
                     batch["ref_policy_log_probs_response_" + str(ind)] = logps
                     ind += 1
diff --git a/nemo_aligner/data/nlp/builders.py b/nemo_aligner/data/nlp/builders.py
index ef2bcb54f..e90fb3aeb 100644
--- a/nemo_aligner/data/nlp/builders.py
+++ b/nemo_aligner/data/nlp/builders.py
@@ -43,11 +43,11 @@
 from nemo.utils import logging
 from nemo_aligner.data.nlp.datasets import (
     DPOModelDataset,
-    RPOModelDataset,
     KTOModelDataset,
     RegressionRewardModelDataset,
     RewardModelDataset,
     RLHFDataset,
+    RPOModelDataset,
 )
 from nemo_aligner.utils import parallel_state
 from nemo_aligner.utils.utils import collate_with_batch_max_sequence_length
diff --git a/nemo_aligner/data/nlp/datasets.py b/nemo_aligner/data/nlp/datasets.py
index 73679af0a..2a2a4e0fa 100644
--- a/nemo_aligner/data/nlp/datasets.py
+++ b/nemo_aligner/data/nlp/datasets.py
@@ -15,11 +15,11 @@
 """Custom datasets for RLHF training"""

 import os
+import random

 import numpy as np
 import scipy
 import torch
-import random

 from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import _create_ltor_masks_and_position_ids
 from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset
@@ -413,20 +413,20 @@ def __getitem__(self, idx):
         prompt, prompt_len = self.encode(payload["prompt"], append_eod=False)
         responses = []
         labels = []
-
+
         # loop on responses of the given prompt to encode them
-        for resp in payload['responses']:
-            resp_tokens, resp_len = self.encode(
-                resp, append_eod=self.cfg.data.get("append_eod", False)
-            )
-
+        for resp in payload["responses"]:
+            resp_tokens, resp_len = self.encode(resp, append_eod=self.cfg.data.get("append_eod", False))
+
             resp_tokens = prompt + resp_tokens
             resp_len = len(resp_tokens)
-
+
             responses.append((resp_tokens, resp_len))
             labels.append(([-100] * prompt_len) + resp_tokens[prompt_len:])
-            assert resp_tokens[0:prompt_len] == prompt, "the tokenizer for DPO has merged tokens between prompt and response"
+            assert (
+                resp_tokens[0:prompt_len] == prompt
+            ), "the tokenizer for DPO has merged tokens between prompt and response"

         max_curr_seq_len = max([i[1] for i in responses])
         if max_curr_seq_len > self.seq_length:
@@ -437,7 +437,7 @@ def __getitem__(self, idx):
         rewards = payload.get("rewards", [random.random() for _ in range(len(responses))])

         resp_dict = {}
-
+
         for ind, (resp, resp_len) in enumerate(responses):
             resp_tokens = torch.nn.functional.pad(
                 torch.LongTensor(resp), (0, max_curr_seq_len - resp_len), mode="constant", value=self.eos_id
@@ -446,20 +446,19 @@ def __getitem__(self, idx):
             label_tokens = torch.nn.functional.pad(
                 torch.LongTensor(label), (0, max_curr_seq_len - len(label)), mode="constant", value=-100
             )
-
+
             # slice if necessary
             if max_curr_seq_len > self.seq_length:
                 resp_tokens = resp_tokens[: self.nograd_length]
                 label_tokens = torch.ones_like(resp_tokens) * (-100)
                 resp_len = self.nograd_length
-
-            resp_dict['response_' + str(ind+1)] = resp_tokens
-            resp_dict['labels_' + str(ind+1)] = label_tokens
-            resp_dict['lengths_' + str(ind+1)] = resp_len
-            resp_dict['rewards_' + str(ind+1)] = rewards[ind]
-
-        return resp_dict
+            resp_dict["response_" + str(ind + 1)] = resp_tokens
+            resp_dict["labels_" + str(ind + 1)] = label_tokens
+            resp_dict["lengths_" + str(ind + 1)] = resp_len
+            resp_dict["rewards_" + str(ind + 1)] = rewards[ind]
+
+        return resp_dict


 class KTOModelDataset(Dataset):
diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py
index d4ab4a87b..430e48ca6 100644
--- a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py
+++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py
@@ -121,20 +121,22 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
             # creating tokens and labels tensor batches
             tokens, labels, ref_logprobs, gt_rewards = None, None, None, None
             if batch["response_1"] is not None:
-                tokens = torch.cat(tuple(batch['response_' + str(i+1)] for i in range(self.k_len)), dim=0)
+                tokens = torch.cat(tuple(batch["response_" + str(i + 1)] for i in range(self.k_len)), dim=0)
             if batch["labels_1"] is not None:
-                labels = torch.cat(tuple(batch['labels_' + str(i+1)] for i in range(self.k_len)), dim=0)
+                labels = torch.cat(tuple(batch["labels_" + str(i + 1)] for i in range(self.k_len)), dim=0)
             if batch["rewards_1"] is not None:
-                gt_rewards = torch.cat(tuple(batch['rewards_' + str(i+1)] for i in range(self.k_len)), dim=0)
+                gt_rewards = torch.cat(tuple(batch["rewards_" + str(i + 1)] for i in range(self.k_len)), dim=0)
             if batch.get("ref_policy_log_probs_response_1") is not None:
-                ref_logprobs = torch.cat(tuple(batch['ref_policy_log_probs_response_' + str(i+1)] for i in range(self.k_len)), dim=0)
+                ref_logprobs = torch.cat(
+                    tuple(batch["ref_policy_log_probs_response_" + str(i + 1)] for i in range(self.k_len)), dim=0
+                )

             # this is necessary if MBS > 1 with the new GBS padding logic, as you may get batch dim > 1 in some configs
             # these two lines ensure your position_ids and attn_mask are always B=1
-            attention_mask = batch['attention_mask'][0:1]
+            attention_mask = batch["attention_mask"][0:1]

             # Model forward pass
             forward_args = {
@@ -195,7 +197,6 @@ def loss_func(output_tensor):
                 )

                 loss = self.preference_loss_weight * preference_loss + self.sft_loss_weight * sft_loss
-
                 (
                     reduced_loss,
                     reduced_preference_loss,
@@ -245,11 +246,15 @@ def log_sum_exp(self, x):
         return max_x + torch.log(torch.sum(torch.exp(x - max_x)))

     def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_probs=False):
-        if self.preference_loss == 'rpo':
+        if self.preference_loss == "rpo":
             # estimated rewards
-            rewards_pred = torch.stack(self.split_output_tensor(self.get_reduced_masked_logps(
-                self.beta * (pi_logprobs - ref_logprobs), labels, average_log_probs=average_log_probs
-            )))
+            rewards_pred = torch.stack(
+                self.split_output_tensor(
+                    self.get_reduced_masked_logps(
+                        self.beta * (pi_logprobs - ref_logprobs), labels, average_log_probs=average_log_probs
+                    )
+                )
+            )

             # based on GT rewards
             gt_rewards = torch.stack(self.split_output_tensor(gt_rewards))
@@ -257,17 +262,29 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p
         else:
             raise ValueError("Unknown RPO Loss")

-        loss = ( torch.nn.functional.softmax(p_star, dim=0) * (torch.nn.functional.log_softmax( p_star, dim=0 ) - torch.nn.functional.log_softmax( rewards_pred, dim=0 )) ).sum(0).mean(0)
-
+        loss = (
+            (
+                torch.nn.functional.softmax(p_star, dim=0)
+                * (
+                    torch.nn.functional.log_softmax(p_star, dim=0)
+                    - torch.nn.functional.log_softmax(rewards_pred, dim=0)
+                )
+            )
+            .sum(0)
+            .mean(0)
+        )
+
         # adding accuracy for the best rewards -> MSE or best accuracy?
         acc_best_resp = (torch.argmax(rewards_pred, dim=0) == torch.argmax(gt_rewards, dim=0)).float().mean()

         return loss, acc_best_resp

     def sft_loss_func(self, pi_logprobs, labels, gt_rewards, average_log_probs=False):
-        logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) # [16]
-        all_log_probs = torch.stack(self.split_output_tensor(logprobs)) # [4, 4] -> each has several responses which we select the best?
-        gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) # same, we split the rewards
+        logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs)  # [16]
+        all_log_probs = torch.stack(
+            self.split_output_tensor(logprobs)
+        )  # [4, 4] -> each has several responses which we select the best?
+        gt_rewards = torch.stack(self.split_output_tensor(gt_rewards))  # same, we split the rewards
         chosen_best = torch.argmax(gt_rewards, dim=0)

         chosen_logprobs = all_log_probs[chosen_best, torch.arange(all_log_probs.size(1))]
@@ -288,7 +305,8 @@ def get_loss_and_metrics(self, batch, forward_only):
             num_microbatches=get_num_microbatches(),
             forward_only=forward_only,
             seq_length=seq_length,
-            micro_batch_size=self.cfg.micro_batch_size * self.k_len, # each minibatch has K comparisons so tensor shape will be mbs * num_responses
+            micro_batch_size=self.cfg.micro_batch_size
+            * self.k_len,  # each minibatch has K comparisons so tensor shape will be mbs * num_responses
         )

         # only the last stages of the pipeline return losses
@@ -385,16 +403,16 @@ def get_logprob_batch(self, batch):
             collect_non_loss_data=True,
         )

-        each_response_list = [ [] for _ in range(self.k_len) ]
+        each_response_list = [[] for _ in range(self.k_len)]
         if len(logprobs_list) > 0:
             for item in logprobs_list:
                 all_log_probs = self.split_output_tensor(item["logprobs"])
                 for ind in range(self.k_len):
                     each_response_list[ind].extend(all_log_probs[ind])
-            each_response_list = [ torch.stack(b, dim=0) for b in each_response_list ]
+            each_response_list = [torch.stack(b, dim=0) for b in each_response_list]
             logprobs = torch.cat(each_response_list, dim=0)
-
+
         else:
             logprobs = None
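For context, black's re-wrapping above makes the RPO objective in `loss_func` harder to read in diff form, so the standalone sketch below restates what that expression computes: a KL divergence between the distribution obtained by soft-maxing `p_star` (the reward-derived target logits; how they are built from the ground-truth rewards lies outside the hunks shown here) and the distribution implied by the predicted rewards `self.beta * (pi_logprobs - ref_logprobs)`, summed over the K responses and averaged over the batch. The helper name, tensor shapes, and `beta` value below are illustrative assumptions, not code from this repository.

# Illustrative sketch only, assuming [K, B] tensors (K responses per prompt, batch size B).
import torch
import torch.nn.functional as F


def rpo_distribution_matching_loss(rewards_pred, p_star):
    # KL(softmax(p_star) || softmax(rewards_pred)) across the K responses, averaged over the batch
    target = F.softmax(p_star, dim=0)
    loss = (target * (F.log_softmax(p_star, dim=0) - F.log_softmax(rewards_pred, dim=0))).sum(0).mean(0)
    # fraction of prompts where the top predicted reward picks the top target reward
    acc_best_resp = (torch.argmax(rewards_pred, dim=0) == torch.argmax(p_star, dim=0)).float().mean()
    return loss, acc_best_resp


K, B, beta = 4, 2, 0.1  # assumed values for the example
pi_logprobs, ref_logprobs, gt_rewards = torch.randn(K, B), torch.randn(K, B), torch.randn(K, B)
rewards_pred = beta * (pi_logprobs - ref_logprobs)  # implicit rewards, as in the patch
loss, acc = rpo_distribution_matching_loss(rewards_pred, gt_rewards)  # gt_rewards stands in for p_star here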