diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 700308013d0b..1b9d4f08b3ea 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -1988,7 +1988,7 @@ def get_expected_keys(inputs, keys):
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
 
-            if self.args.enable_sharding_comm_overlap:
+            if hasattr(self.args, "enable_sharding_comm_overlap") and self.args.enable_sharding_comm_overlap:
                 model.register_sharding_comm_overlap_hook(self.optimizer)
 
         # No pipeline mode, sharding only
@@ -2764,7 +2764,9 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 else:
                     opt_state_dict = None
             else:
-                model = self.model_wrapped if self.args.enable_sharding_comm_overlap else self.model
+                model = self.model
+                if hasattr(self.args, "enable_sharding_comm_overlap") and self.args.enable_sharding_comm_overlap:
+                    model = self.model_wrapped
                 opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
                     model=model,
                     optimizer=self.optimizer,
diff --git a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
index 17a9f0782221..fda80fca0a61 100644
--- a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
+++ b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
@@ -181,7 +181,7 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
     param_shape_info = {}
 
     comm_buffer_list = optimizer._inner_opt._comm_buffer_list
-    if args.enable_sharding_comm_overlap:
+    if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
         comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
         model = unwrap_model(model)
 
diff --git a/paddlenlp/trl/dpo_criterion.py b/paddlenlp/trl/dpo_criterion.py
index 2af2a8ef2096..be454e2ce4d1 100644
--- a/paddlenlp/trl/dpo_criterion.py
+++ b/paddlenlp/trl/dpo_criterion.py
@@ -287,10 +287,10 @@ def forward(
             )
             loss = dpo_loss + sft_loss
         if self.use_infohub:
-            infohub.policy_chosen_logps.append(policy_chosen_logps)
-            infohub.policy_rejected_logps.append(policy_rejected_logps)
-            infohub.sft_loss.append(sft_loss)
-            infohub.dpo_loss.append(dpo_loss)
+            infohub.policy_chosen_logps.append(policy_chosen_logps.detach())
+            infohub.policy_rejected_logps.append(policy_rejected_logps.detach())
+            infohub.sft_loss.append(sft_loss.detach())
+            infohub.dpo_loss.append(dpo_loss.detach())
             return loss
         else:
             return policy_chosen_logps, policy_rejected_logps, sft_loss, dpo_loss, loss
diff --git a/paddlenlp/trl/kto_criterion.py b/paddlenlp/trl/kto_criterion.py
index 52745e996999..a6ca6c4c837a 100644
--- a/paddlenlp/trl/kto_criterion.py
+++ b/paddlenlp/trl/kto_criterion.py
@@ -247,10 +247,10 @@ def forward(
                 reference_kl_logps,
             )
         if self.use_infohub:
-            infohub.policy_chosen_logps.append(policy_chosen_logps)
-            infohub.policy_rejected_logps.append(policy_rejected_logps)
-            infohub.policy_kl_logps.append(policy_kl_logps)
-            infohub.kl.append(kl)
+            infohub.policy_chosen_logps.append(policy_chosen_logps.detach())
+            infohub.policy_rejected_logps.append(policy_rejected_logps.detach())
+            infohub.policy_kl_logps.append(policy_kl_logps.detach())
+            infohub.kl.append(kl.detach())
             return loss
         else:
             return (