diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 700308013d0b..8d88e26ea37b 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2764,7 +2764,9 @@ def _load_optimizer_and_scheduler(self, checkpoint): else: opt_state_dict = None else: - model = self.model_wrapped if self.args.enable_sharding_comm_overlap else self.model + model = self.model + if hasattr(self.args, "enable_sharding_comm_overlap") and self.args.enable_sharding_comm_overlap: + model = self.model_wrapped opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer( model=model, optimizer=self.optimizer, diff --git a/paddlenlp/trl/dpo_criterion.py b/paddlenlp/trl/dpo_criterion.py index 2af2a8ef2096..be454e2ce4d1 100644 --- a/paddlenlp/trl/dpo_criterion.py +++ b/paddlenlp/trl/dpo_criterion.py @@ -287,10 +287,10 @@ def forward( ) loss = dpo_loss + sft_loss if self.use_infohub: - infohub.policy_chosen_logps.append(policy_chosen_logps) - infohub.policy_rejected_logps.append(policy_rejected_logps) - infohub.sft_loss.append(sft_loss) - infohub.dpo_loss.append(dpo_loss) + infohub.policy_chosen_logps.append(policy_chosen_logps.detach()) + infohub.policy_rejected_logps.append(policy_rejected_logps.detach()) + infohub.sft_loss.append(sft_loss.detach()) + infohub.dpo_loss.append(dpo_loss.detach()) return loss else: return policy_chosen_logps, policy_rejected_logps, sft_loss, dpo_loss, loss diff --git a/paddlenlp/trl/kto_criterion.py b/paddlenlp/trl/kto_criterion.py index 52745e996999..a6ca6c4c837a 100644 --- a/paddlenlp/trl/kto_criterion.py +++ b/paddlenlp/trl/kto_criterion.py @@ -247,10 +247,10 @@ def forward( reference_kl_logps, ) if self.use_infohub: - infohub.policy_chosen_logps.append(policy_chosen_logps) - infohub.policy_rejected_logps.append(policy_rejected_logps) - infohub.policy_kl_logps.append(policy_kl_logps) - infohub.kl.append(kl) + infohub.policy_chosen_logps.append(policy_chosen_logps.detach()) + infohub.policy_rejected_logps.append(policy_rejected_logps.detach()) + infohub.policy_kl_logps.append(policy_kl_logps.detach()) + infohub.kl.append(kl.detach()) return loss else: return (