Degert/fix dpo capitalisation (#34)
trias702 authored Dec 2, 2023
1 parent 1cf5657 commit 8ca2e04
Showing 1 changed file with 1 addition and 1 deletion.
nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py (2 changes: 1 addition & 1 deletion)
@@ -365,7 +365,7 @@ def get_ref_policy_logprobs(self, list_of_batches):
[torch.cat((b["chosen_labels"], b["rejected_labels"]), dim=0) for b in list_of_batches], dim=0
)
global_batch = [tokens, masks, pos_ids, labels]
-        with cpu_weight_swap(self, self.ref_policy_state_dict, megatron_amp_O2=self.megatron_amp_o2):
+        with cpu_weight_swap(self, self.ref_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2):
ref_log_probs = self.get_logprob_batch(global_batch)

# return in GPU, trainer needs to move to cpu
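The one-character change corrects the attribute read on the right-hand side of the keyword argument: the pre-fix code passed self.megatron_amp_o2 (lowercase), while the attribute on the model is spelled megatron_amp_O2, matching the keyword the cpu_weight_swap call already used. As a minimal illustration of why such a case mismatch fails at runtime (FakeDPOModel below is a hypothetical stand-in for the model class, not NeMo-Aligner code):

# Python attribute names are case sensitive, so reading the lowercase
# spelling on an object that only defines the uppercase one raises
# AttributeError. Sketch only; the class and flag value are made up.

class FakeDPOModel:
    """Hypothetical stand-in; only the O2 flag matters for this example."""

    def __init__(self):
        # Attribute defined with uppercase "O2", as in the fixed line.
        self.megatron_amp_O2 = True


model = FakeDPOModel()

print(model.megatron_amp_O2)      # post-fix spelling: prints True

try:
    print(model.megatron_amp_o2)  # pre-fix spelling
except AttributeError as err:
    print(f"AttributeError: {err}")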
