diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py
index f52e662918eb..0bfa61364ece 100644
--- a/paddlenlp/peft/lora/lora_model.py
+++ b/paddlenlp/peft/lora/lora_model.py
@@ -165,6 +165,7 @@ def __init__(self, model, lora_config: LoRAConfig) -> None:
         self.forward = self.model.forward
         if lora_config.loraga:
             self.loraga_init_dict = {}
+            self.reinit_base_model = False

         logger.info("Mark only lora and trainable_module as trainable.")
         self.mark_only_lora_as_trainable()
@@ -349,11 +350,11 @@ def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict):
                     )

                     base_name = name.replace("lora_A", "weight")
-
-                    # Reinit base model
-                    offset = init_loraA.cuda() @ init_loraB.cuda()
-                    ori_weight = model_state_dict[base_name]
-                    model_state_dict[base_name].set_value(ori_weight - self.lora_config.scaling * offset)
+                    if not self.reinit_base_model:
+                        # Reinit base model
+                        offset = init_loraA.cuda() @ init_loraB.cuda()
+                        ori_weight = model_state_dict[base_name]
+                        model_state_dict[base_name].set_value(ori_weight - self.lora_config.scaling * offset)
             del model_state_dict
             gc.collect()
         self.model.set_state_dict(state_dict)
diff --git a/paddlenlp/peft/lora/loraga_utils.py b/paddlenlp/peft/lora/loraga_utils.py
index 27cba7770316..5c821772ea32 100644
--- a/paddlenlp/peft/lora/loraga_utils.py
+++ b/paddlenlp/peft/lora/loraga_utils.py
@@ -74,10 +74,9 @@ def estimate_gradient(self, model: PretrainedModel):
         ):
             for batch in dataloader:
                 iters += 1
-                batch = {k: paddle.to_tensor(v) for k, v in batch.items()}
-
                 # Pipeline parallel not supported currently
-                loss, logits = model(**batch)
+                with paddle.amp.auto_cast(enable=True, custom_black_list=self.args.amp_custom_black_list):
+                    loss, logits = model(**batch)
                 loss.backward()

                 if iters == self.loraga_init_iters:
@@ -160,13 +159,21 @@ def get_module_gradient(
     rank_suffix = "_" + str(local_rank)
     local_grad_name = ".".join(grad_name.split(".")[1:]) + ".weight" + rank_suffix
     gradient = gradient_dict.pop(local_grad_name).cuda()
+
+    is_fleet_init = True
+    try:
+        hcg = fleet.get_hybrid_communicate_group()
+        model_parallel_group = hcg.get_model_parallel_group()
+        sharding_parallel_group = hcg.get_sharding_parallel_group()
+        data_parallel_group = hcg.get_data_parallel_group()
+    except:
+        is_fleet_init = False
+
     if tp_degree > 1:
         # remove prefix and suffix in name
         model_split_key = local_grad_name.split(base_model_prefix)[-1].rsplit(rank_suffix, 1)[0]
         if model_split_key in base_model_split_mappings:
             merge_func = base_model_split_mappings[model_split_key]
-            hcg = fleet.get_hybrid_communicate_group()
-            model_parallel_group = hcg.get_model_parallel_group()

             output_tensors = []
             dist.all_gather(output_tensors, gradient, group=model_parallel_group)
@@ -175,18 +182,17 @@ def get_module_gradient(

     # sharding
     if sharding_degree > 1:
-        hcg = fleet.get_hybrid_communicate_group()
-        sharding_parallel_group = hcg.get_sharding_parallel_group()
         if sharding_parallel_group.nranks > 1:
-
             dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=sharding_parallel_group)
             gradient /= sharding_parallel_group.nranks
+
     # dp
     if dp_degree > 1:
-        hcg = fleet.get_hybrid_communicate_group()
-        data_parallel_group = hcg.get_data_parallel_group()
         if data_parallel_group.nranks > 1:
-            dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=data_parallel_group)
+            if is_fleet_init:
+                dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=data_parallel_group)
+            else:
+                dist.all_reduce(gradient, op=dist.ReduceOp.SUM)
             gradient /= data_parallel_group.nranks

     return gradient
@@ -250,6 +256,7 @@ def loraga_svd_reinit(
                 lora_split_mapping,
                 **kwargs,
             )
+    model.reinit_base_model = True
     model.loraga_init_dict = loraga_init_dict


diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 57c655736f25..9859927d6946 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -623,6 +623,8 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                     self.model,
                     resume_from_checkpoint,
                 )
+                if isinstance(self.model, LoRAModel) and self.model.lora_config.loraga:
+                    self.model.reinit_base_model = True
                 logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.")
                 self.runtime_timer.stop()
                 return
@@ -635,6 +637,8 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
             or isinstance(self.model, ReFTModel)
         ):
             self._load_from_peft_checkpoint(resume_from_checkpoint)
+            if isinstance(self.model, LoRAModel) and self.model.lora_config.loraga:
+                self.model.reinit_base_model = True
             self.runtime_timer.stop()
             return
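A minimal illustrative sketch (not part of the patch) of what the new reinit_base_model flag guards: on a fresh LoRA-GA initialization the base weight is shifted by the scaled product of the initial LoRA factors, while a checkpoint resume sets the flag so the already-shifted weight is not shifted a second time. The standalone helper name below is hypothetical; the tensor and parameter names mirror the diff.

    # Hypothetical helper mirroring the guarded reinit in LoRAModel.set_state_dict above.
    def maybe_offset_base_weight(base_weight, init_loraA, init_loraB, scaling, reinit_base_model):
        """Subtract the LoRA-GA init offset from a base weight unless resuming.

        On resume (reinit_base_model=True) the checkpointed base weight already
        contains this subtraction, so applying it again would corrupt the model.
        """
        if not reinit_base_model:
            offset = init_loraA @ init_loraB  # product of the initial LoRA factors
            base_weight.set_value(base_weight - scaling * offset)
        return base_weight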