diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
index 6414dbe9a..f37e8c2db 100644
--- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
+++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
@@ -365,7 +365,7 @@ def get_ref_policy_logprobs(self, list_of_batches):
             [torch.cat((b["chosen_labels"], b["rejected_labels"]), dim=0) for b in list_of_batches], dim=0
         )
         global_batch = [tokens, masks, pos_ids, labels]
-        with cpu_weight_swap(self, self.ref_policy_state_dict, megatron_amp_O2=self.megatron_amp_o2):
+        with cpu_weight_swap(self, self.ref_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2):
             ref_log_probs = self.get_logprob_batch(global_batch)
 
         # return in GPU, trainer needs to move to cpu
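
The fix corrects the attribute name passed to `cpu_weight_swap`: the keyword argument is `megatron_amp_O2`, so the value should come from `self.megatron_amp_O2` rather than a lowercased `self.megatron_amp_o2`, which would fail at runtime when reference-policy logprobs are computed. For context, the sketch below illustrates the general weight-swap idea behind `cpu_weight_swap` (temporarily loading a CPU-resident reference state dict into the live model, then restoring the original weights). It is a hypothetical, simplified re-implementation for illustration only, not NeMo Aligner's actual code, and the name `swap_in_state_dict` is invented here.

```python
# Minimal sketch of a CPU weight-swap context manager (hypothetical,
# not NeMo Aligner's cpu_weight_swap): load reference weights into the
# model for the duration of the block, then restore the policy weights.
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def swap_in_state_dict(model: nn.Module, ref_state_dict: dict):
    # Stash the current (policy) weights on CPU before overwriting them.
    saved = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    model.load_state_dict(ref_state_dict)
    try:
        yield model
    finally:
        # Restore the policy weights even if the block raises.
        model.load_state_dict(saved)


if __name__ == "__main__":
    policy = nn.Linear(4, 2)
    ref_state = {k: torch.zeros_like(v) for k, v in policy.state_dict().items()}
    with swap_in_state_dict(policy, ref_state):
        # Inside the block the model carries the reference weights.
        assert policy.weight.abs().sum().item() == 0.0
    # Outside the block the original policy weights are back.
    assert policy.weight.abs().sum().item() != 0.0
```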