
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Sep 24, 2024
1 parent b2145fd commit bf90fdb
Showing 4 changed files with 69 additions and 53 deletions.
29 changes: 14 additions & 15 deletions nemo_aligner/algorithms/rpo.py
@@ -32,39 +32,38 @@

def rpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):
resp_outputs = {}

## assume len 4 for responses
for k in batch[0].keys():
if k.startswith('response_'):
if k.startswith("response_"):
# get response_i
resp_outputs[k] = torch.nn.utils.rnn.pad_sequence(
[item[k] for item in batch],
batch_first=True, padding_value=eos_id)
elif k.startswith('labels_'):
[item[k] for item in batch], batch_first=True, padding_value=eos_id
)
elif k.startswith("labels_"):
# get labels_i
resp_outputs[k] = torch.nn.utils.rnn.pad_sequence(
[item[k] for item in batch],
batch_first=True, padding_value=-100)
elif k.startswith('lengths_'):
[item[k] for item in batch], batch_first=True, padding_value=-100
)
elif k.startswith("lengths_"):
# get lens_i
resp_outputs[k] = torch.LongTensor([item[k] for item in batch])
elif k.startswith('rewards_'):
elif k.startswith("rewards_"):
# get r_i
resp_outputs[k] = torch.FloatTensor([item[k] for item in batch])

attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
resp_outputs['response_1'], eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
resp_outputs["response_1"], eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
)
assert attention_mask.ndim == 4, "attention_mask is incorrect shape for rpo_custom_collate"
if attention_mask.shape[0] == 1:
# using .expand() here causes errors from pin_memory=True, so need to use .repeat()
# attention_mask = attention_mask.expand(len(batch), *((-1,) * (len(attention_mask.shape) - 1)))
attention_mask = attention_mask.repeat(4, *((1,) * (len(attention_mask.shape) - 1)))


resp_outputs["attention_mask"] = attention_mask
resp_outputs["position_ids"] = position_ids

return resp_outputs


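As a side note on the collate logic above, here is a minimal standalone sketch (not part of the commit) of the padding behaviour that rpo_custom_collate relies on; the token values and eos id are made up for illustration:

import torch

# two variable-length responses, right-padded to a common length with eos_id
responses = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
eos_id = 0
padded = torch.nn.utils.rnn.pad_sequence(responses, batch_first=True, padding_value=eos_id)
print(padded)  # tensor([[5, 6, 7], [8, 9, 0]])

# labels use -100 as the fill value so padded positions are ignored by the loss
labels = [torch.tensor([-100, 6, 7]), torch.tensor([-100, 9])]
padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
print(padded_labels)  # tensor([[-100, 6, 7], [-100, 9, -100]])

On the mask handling: the .repeat() call in the diff materializes a real copy of the single attention mask for all responses, whereas .expand() would only return a view, which the in-code comment reports breaks pin_memory=True in the DataLoader.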
@@ -319,8 +318,8 @@ def augment_dataloader(self, dataloader):
batch = next(iter_dataloader)
logprobs = self.model.get_ref_policy_logprobs(batch).cpu()
ind = 1
for logps in torch.split(logprobs, len(logprobs) // self.k_len, dim=0):

for logps in torch.split(logprobs, len(logprobs) // self.k_len, dim=0):
batch["ref_policy_log_probs_response_" + str(ind)] = logps
ind += 1

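For context, a hypothetical illustration (not from the commit) of the split performed in augment_dataloader: the reference-policy log-probs for K responses are stacked along the batch dimension and cut back into K equal chunks, one per response key. The shapes and k_len below are assumptions.

import torch

k_len = 4                      # assumed number of responses per prompt
logprobs = torch.randn(8, 16)  # e.g. 2 prompts x 4 responses, sequence length 16 (made up)

batch = {}
for ind, logps in enumerate(torch.split(logprobs, len(logprobs) // k_len, dim=0), start=1):
    batch["ref_policy_log_probs_response_" + str(ind)] = logps

print({k: tuple(v.shape) for k, v in batch.items()})  # every value is (2, 16)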
2 changes: 1 addition & 1 deletion nemo_aligner/data/nlp/builders.py
@@ -43,11 +43,11 @@
from nemo.utils import logging
from nemo_aligner.data.nlp.datasets import (
DPOModelDataset,
RPOModelDataset,
KTOModelDataset,
RegressionRewardModelDataset,
RewardModelDataset,
RLHFDataset,
RPOModelDataset,
)
from nemo_aligner.utils import parallel_state
from nemo_aligner.utils.utils import collate_with_batch_max_sequence_length
35 changes: 17 additions & 18 deletions nemo_aligner/data/nlp/datasets.py
@@ -15,11 +15,11 @@
"""Custom datasets for RLHF training"""

import os
import random

import numpy as np
import scipy
import torch
import random

from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import _create_ltor_masks_and_position_ids
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset
@@ -413,20 +413,20 @@ def __getitem__(self, idx):
prompt, prompt_len = self.encode(payload["prompt"], append_eod=False)
responses = []
labels = []

# loop on responses of the given prompt to encode them
for resp in payload['responses']:
resp_tokens, resp_len = self.encode(
resp, append_eod=self.cfg.data.get("append_eod", False)
)

for resp in payload["responses"]:
resp_tokens, resp_len = self.encode(resp, append_eod=self.cfg.data.get("append_eod", False))

resp_tokens = prompt + resp_tokens
resp_len = len(resp_tokens)

responses.append((resp_tokens, resp_len))
labels.append(([-100] * prompt_len) + resp_tokens[prompt_len:])

assert resp_tokens[0:prompt_len] == prompt, "the tokenizer for DPO has merged tokens between prompt and response"
assert (
resp_tokens[0:prompt_len] == prompt
), "the tokenizer for DPO has merged tokens between prompt and response"

max_curr_seq_len = max([i[1] for i in responses])
if max_curr_seq_len > self.seq_length:
@@ -437,7 +437,7 @@ def __getitem__(self, idx):

rewards = payload.get("rewards", [random.random() for _ in range(len(responses))])
resp_dict = {}

for ind, (resp, resp_len) in enumerate(responses):
resp_tokens = torch.nn.functional.pad(
torch.LongTensor(resp), (0, max_curr_seq_len - resp_len), mode="constant", value=self.eos_id
@@ -446,20 +446,19 @@
label_tokens = torch.nn.functional.pad(
torch.LongTensor(label), (0, max_curr_seq_len - len(label)), mode="constant", value=-100
)

# slice if necessary
if max_curr_seq_len > self.seq_length:
resp_tokens = resp_tokens[: self.nograd_length]
label_tokens = torch.ones_like(resp_tokens) * (-100)
resp_len = self.nograd_length

resp_dict['response_' + str(ind+1)] = resp_tokens
resp_dict['labels_' + str(ind+1)] = label_tokens
resp_dict['lengths_' + str(ind+1)] = resp_len
resp_dict['rewards_' + str(ind+1)] = rewards[ind]

return resp_dict

resp_dict["response_" + str(ind + 1)] = resp_tokens
resp_dict["labels_" + str(ind + 1)] = label_tokens
resp_dict["lengths_" + str(ind + 1)] = resp_len
resp_dict["rewards_" + str(ind + 1)] = rewards[ind]

return resp_dict


class KTOModelDataset(Dataset):
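Before moving on to the model file, a small self-contained sketch (assumptions only, not the dataset class itself) of how RPOModelDataset.__getitem__ builds labels: prompt positions are masked with -100 so the loss skips them, and each response is padded out to the longest response of the example. The token ids, eos id, and lengths below are invented.

import torch

prompt = [101, 102, 103]             # made-up prompt token ids
resp_tokens = prompt + [7, 8, 9, 2]  # prompt + encoded response (2 standing in for eod)
prompt_len = len(prompt)
eos_id = 0                           # assumed pad/eos id

labels = [-100] * prompt_len + resp_tokens[prompt_len:]
print(labels)  # [-100, -100, -100, 7, 8, 9, 2]

max_curr_seq_len = 9  # assume the longest response in this example has length 9
resp_len = len(resp_tokens)
padded_resp = torch.nn.functional.pad(
    torch.LongTensor(resp_tokens), (0, max_curr_seq_len - resp_len), mode="constant", value=eos_id
)
padded_labels = torch.nn.functional.pad(
    torch.LongTensor(labels), (0, max_curr_seq_len - len(labels)), mode="constant", value=-100
)
print(padded_resp.tolist())    # [101, 102, 103, 7, 8, 9, 2, 0, 0]
print(padded_labels.tolist())  # [-100, -100, -100, 7, 8, 9, 2, -100, -100]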
56 changes: 37 additions & 19 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py
@@ -121,20 +121,22 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
# creating tokens and labels tensor batches
tokens, labels, ref_logprobs, gt_rewards = None, None, None, None
if batch["response_1"] is not None:
tokens = torch.cat(tuple(batch['response_' + str(i+1)] for i in range(self.k_len)), dim=0)
tokens = torch.cat(tuple(batch["response_" + str(i + 1)] for i in range(self.k_len)), dim=0)

if batch["labels_1"] is not None:
labels = torch.cat(tuple(batch['labels_' + str(i+1)] for i in range(self.k_len)), dim=0)
labels = torch.cat(tuple(batch["labels_" + str(i + 1)] for i in range(self.k_len)), dim=0)

if batch["rewards_1"] is not None:
gt_rewards = torch.cat(tuple(batch['rewards_' + str(i+1)] for i in range(self.k_len)), dim=0)
gt_rewards = torch.cat(tuple(batch["rewards_" + str(i + 1)] for i in range(self.k_len)), dim=0)

if batch.get("ref_policy_log_probs_response_1") is not None:
ref_logprobs = torch.cat(tuple(batch['ref_policy_log_probs_response_' + str(i+1)] for i in range(self.k_len)), dim=0)
ref_logprobs = torch.cat(
tuple(batch["ref_policy_log_probs_response_" + str(i + 1)] for i in range(self.k_len)), dim=0
)

# this is necessary if MBS > 1 with the new GBS padding logic, as you may get batch dim > 1 in some configs
# these two lines ensure your position_ids and attn_mask are always B=1
attention_mask = batch['attention_mask'][0:1]
attention_mask = batch["attention_mask"][0:1]

# Model forward pass
forward_args = {
@@ -195,7 +197,6 @@ def loss_func(output_tensor):
)
loss = self.preference_loss_weight * preference_loss + self.sft_loss_weight * sft_loss


(
reduced_loss,
reduced_preference_loss,
@@ -245,29 +246,45 @@ def log_sum_exp(self, x):
return max_x + torch.log(torch.sum(torch.exp(x - max_x)))

def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_probs=False):
if self.preference_loss == 'rpo':
if self.preference_loss == "rpo":
# estimated rewards
rewards_pred = torch.stack(self.split_output_tensor(self.get_reduced_masked_logps(
self.beta * (pi_logprobs - ref_logprobs), labels, average_log_probs=average_log_probs
)))
rewards_pred = torch.stack(
self.split_output_tensor(
self.get_reduced_masked_logps(
self.beta * (pi_logprobs - ref_logprobs), labels, average_log_probs=average_log_probs
)
)
)

# based on GT rewards
gt_rewards = torch.stack(self.split_output_tensor(gt_rewards))
p_star = self.eta * gt_rewards
else:
raise ValueError("Unknown RPO Loss")

loss = ( torch.nn.functional.softmax(p_star, dim=0) * (torch.nn.functional.log_softmax( p_star, dim=0 ) - torch.nn.functional.log_softmax( rewards_pred, dim=0 )) ).sum(0).mean(0)

loss = (
(
torch.nn.functional.softmax(p_star, dim=0)
* (
torch.nn.functional.log_softmax(p_star, dim=0)
- torch.nn.functional.log_softmax(rewards_pred, dim=0)
)
)
.sum(0)
.mean(0)
)

# adding accuracy for the best rewards -> MSE or best accuracy?
acc_best_resp = (torch.argmax(rewards_pred, dim=0) == torch.argmax(gt_rewards, dim=0)).float().mean()

return loss, acc_best_resp

def sft_loss_func(self, pi_logprobs, labels, gt_rewards, average_log_probs=False):
logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) # [16]
all_log_probs = torch.stack(self.split_output_tensor(logprobs)) # [4, 4] -> each has several responses which we select the best?
gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) # same, we split the rewards
logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) # [16]
all_log_probs = torch.stack(
self.split_output_tensor(logprobs)
) # [4, 4] -> each has several responses which we select the best?
gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) # same, we split the rewards
chosen_best = torch.argmax(gt_rewards, dim=0)

chosen_logprobs = all_log_probs[chosen_best, torch.arange(all_log_probs.size(1))]
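To make the "rpo" branch above concrete, here is a hypothetical numeric sketch (not from the commit) of the loss: a KL divergence over the K responses between the distribution induced by the scaled ground-truth rewards (p_star = eta * gt_rewards) and the one induced by the implicit rewards beta * (pi_logprobs - ref_logprobs). K, the batch size, and the hyperparameters are assumptions.

import torch
import torch.nn.functional as F

K, B = 4, 2            # K responses per prompt, batch of 2 prompts (made up)
eta, beta = 1.0, 0.1   # assumed hyperparameters

gt_rewards = torch.randn(K, B)
rewards_pred = beta * torch.randn(K, B)  # stands in for beta * (pi_logprobs - ref_logprobs), reduced per response

p_star = eta * gt_rewards
loss = (
    F.softmax(p_star, dim=0)
    * (F.log_softmax(p_star, dim=0) - F.log_softmax(rewards_pred, dim=0))
).sum(0).mean(0)  # KL(p_star || rewards_pred) per prompt, averaged over the batch

# does the predicted best response match the ground-truth best response?
acc_best_resp = (torch.argmax(rewards_pred, dim=0) == torch.argmax(gt_rewards, dim=0)).float().mean()
print(loss.item(), acc_best_resp.item())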
@@ -288,7 +305,8 @@ def get_loss_and_metrics(self, batch, forward_only):
num_microbatches=get_num_microbatches(),
forward_only=forward_only,
seq_length=seq_length,
micro_batch_size=self.cfg.micro_batch_size * self.k_len, # each minibatch has K comparisons so tensor shape will be mbs * num_responses
micro_batch_size=self.cfg.micro_batch_size
* self.k_len, # each minibatch has K comparisons so tensor shape will be mbs * num_responses
)

# only the last stages of the pipeline return losses
@@ -385,16 +403,16 @@ def get_logprob_batch(self, batch):
collect_non_loss_data=True,
)

each_response_list = [ [] for _ in range(self.k_len) ]
each_response_list = [[] for _ in range(self.k_len)]

if len(logprobs_list) > 0:
for item in logprobs_list:
all_log_probs = self.split_output_tensor(item["logprobs"])
for ind in range(self.k_len):
each_response_list[ind].extend(all_log_probs[ind])
each_response_list = [ torch.stack(b, dim=0) for b in each_response_list ]
each_response_list = [torch.stack(b, dim=0) for b in each_response_list]
logprobs = torch.cat(each_response_list, dim=0)

else:
logprobs = None

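Finally, a rough sketch (assumptions, not the commit) of the regrouping at the end of get_logprob_batch: each microbatch yields log-probs for K responses stacked along the batch dimension; they are split per response, accumulated across microbatches, and re-concatenated so that all rows for response 1 come first, then response 2, and so on. The split_output_tensor stand-in and the shapes below are invented.

import torch

k_len = 4

def split_output_tensor(t):
    # hypothetical stand-in for the model's helper: cut the stacked tensor into k_len chunks
    return torch.split(t, t.shape[0] // k_len, dim=0)

# pretend three microbatches, each with 8 rows (2 prompts x 4 responses) of length-16 log-probs
logprobs_list = [{"logprobs": torch.randn(8, 16)} for _ in range(3)]

each_response_list = [[] for _ in range(k_len)]
for item in logprobs_list:
    all_log_probs = split_output_tensor(item["logprobs"])
    for ind in range(k_len):
        each_response_list[ind].extend(all_log_probs[ind])  # extend adds individual rows
each_response_list = [torch.stack(b, dim=0) for b in each_response_list]
logprobs = torch.cat(each_response_list, dim=0)
print(logprobs.shape)  # torch.Size([24, 16]): 3 microbatches x 8 rows, grouped response-by-response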
