From d8671dfd883c8a0744084ccff31047c55dc3811b Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Sun, 20 Oct 2024 16:54:39 +0000 Subject: [PATCH 1/9] Support LoRA-GA initialization --- llm/run_finetune.py | 41 +++++++++++++++++ llm/utils/argument.py | 4 ++ paddlenlp/peft/lora/lora_config.py | 4 ++ paddlenlp/utils/llm_utils.py | 74 ++++++++++++++++++++++++++++++ 4 files changed, 123 insertions(+) diff --git a/llm/run_finetune.py b/llm/run_finetune.py index d084a910ff65..ecbd8d6e26eb 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -418,6 +418,42 @@ def neft_post_hook(module, input, output): else None ) + if model_args.loraga: + from paddlenlp.utils.llm_utils import estimate_gradient + + if ( + training_args.pipeline_parallel_degree > 1 + or training_args.sequence_parallel + or training_args.autotuner_benchmark + or data_args.zero_padding + or data_args.pad_to_max_length + ): + # NOTE(gongenlei): new add autotuner_benchmark + max_length = data_args.max_length + padding = "max_length" + else: + max_length = None + padding = True + + data_collator = DataCollatorForSeq2Seq( + tokenizer=tokenizer, + max_length=max_length, + padding=padding, + max_label_length=max_length, + return_tensors="np", + return_attention_mask=not model_args.flash_mask, + pad_to_multiple_of=data_args.pad_to_multiple_of, + ) + named_grads = estimate_gradient( + model=model, + train_ds=train_ds, + data_collator=data_collator, + world_size=training_args.world_size, + batch_size=model_args.loraga_init_bsz, + iters=model_args.loraga_init_iters, + tokenizer=tokenizer, + ) + if model_args.prefix_tuning: if training_args.pipeline_parallel_degree > 1: raise NotImplementedError("Prefix tuning is not implemented for pipeline parallelism.") @@ -445,6 +481,11 @@ def neft_post_hook(module, input, output): ) model.print_trainable_parameters() + if model_args.loraga: + from paddlenlp.utils.llm_utils import loraga_reinit + + loraga_reinit(model, named_grads, stable_gamma=model_args.loraga_stable_gamma) + if model_args.lora: if training_args.sharding_parallel_degree > 1: assert ( diff --git a/llm/utils/argument.py b/llm/utils/argument.py index 60b6f89b3377..4a45fc348154 100644 --- a/llm/utils/argument.py +++ b/llm/utils/argument.py @@ -214,6 +214,10 @@ class ModelArgument: rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"}) lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"}) pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"}) + loraga: bool = field(default=False, metadata={"help": "Whether to use LoRA-GA:https://arxiv.org/pdf/2407.05000"}) + loraga_init_bsz: int = field(default=1, metadata={"help": "The batch size for lora ga"}) + loraga_init_iters: int = field(default=32, metadata={"help": "The number of init iterations for lora ga"}) + loraga_stable_gamma: int = field(default=64, metadata={"help": "Lora Ga stable gamma"}) # vera related parameters vera: bool = field(default=False, metadata={"help": "Whether to use vera technique"}) diff --git a/paddlenlp/peft/lora/lora_config.py b/paddlenlp/peft/lora/lora_config.py index 40b59e5c1a17..7785e2046c03 100644 --- a/paddlenlp/peft/lora/lora_config.py +++ b/paddlenlp/peft/lora/lora_config.py @@ -76,6 +76,10 @@ class LoRAConfig: do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"}) rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"}) pissa: bool = 
field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"}) + loraga: bool = field(default=False, metadata={"help": "Whether to use LoRA-GA:https://arxiv.org/pdf/2407.05000"}) + loraga_init_bsz: int = field(default=1, metadata={"help": "The batch size for lora ga"}) + loraga_init_iters: int = field(default=4, metadata={"help": "The number of init iterations for lora ga"}) + loraga_stable_gamma: int = field(default=32, metadata={"help": "Lora Ga stable gamma"}) lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+"}) base_model_name_or_path: Optional[str] = field( default=None, metadata={"help": "The name of the base model to use."} diff --git a/paddlenlp/utils/llm_utils.py b/paddlenlp/utils/llm_utils.py index 6ef5aae9dfb0..f1e6ee42f388 100644 --- a/paddlenlp/utils/llm_utils.py +++ b/paddlenlp/utils/llm_utils.py @@ -873,3 +873,77 @@ def get_eos_token_id( eos_token_ids_dict = {str(item): item for item in eos_token_ids} return list(eos_token_ids_dict.values()) + + +def estimate_gradient(model, train_ds, data_collator, world_size=1, batch_size=1, iters=4, tokenizer=None): + """Estimate the gradient of the model on the given dataset""" + + logger.info("Estimating gradient for LoraGA") + model.train() + named_grads = {} + dataloader = DataLoader(train_ds, collate_fn=data_collator, batch_size=batch_size, shuffle=False) + num_batch = 0 + for batch in dataloader: + num_batch += 1 + batch = {k: paddle.to_tensor(v) for k, v in batch.items()} + loss, logits = model(**batch) + loss.backward() + + # Record gradients + for grad_name, param in model.named_parameters(): + # if param.stop_gradient is False: + if param.stop_gradient is False and param.grad is not None: + if grad_name not in named_grads: + named_grads[grad_name] = param.grad.clone() + else: + named_grads[grad_name] += param.grad + + param.clear_gradient() + if num_batch == iters: + break + + for grad_name, param in named_grads.items(): + named_grads[grad_name] /= num_batch + + paddle.device.cuda.empty_cache() + return named_grads + + +def loraga_reinit(model, named_grads, stable_gamma, **kwargs): + """Re-initialize the weights of the model using the estimated gradients""" + from tqdm import tqdm + + for name, module in tqdm( + model.named_sublayers(), + desc="Reinitializing Lora", + total=len(list(model.named_sublayers())), + ): + from paddlenlp.peft.lora.lora_layers import LoRALinear + + if isinstance(module, LoRALinear): + loraga_reinit_modules(name, module, named_grads, stable_gamma, **kwargs) + + +def loraga_reinit_modules(name, module, named_grads, stable_gamma, **kwargs): + with paddle.no_grad(): + lora_r = module.r + grad_name = ".".join(name.split(".")[1:]) + ".weight" + grads = named_grads[grad_name] + + U, S, V = paddle.linalg.svd_lowrank(grads.cuda().astype("float32"), q=4 * lora_r, niter=4) + + V = V.T + A = U[:, lora_r : 2 * lora_r] + B = V[:lora_r, :] + m, n = grads.shape # m: feature_out, n: feature_in + # If stable_gamma is not -1, scale the matrices A and B by the square root of the stable_gamma + if stable_gamma != -1: + A = A * m**0.25 / stable_gamma**0.5 + B = B * m**0.25 / stable_gamma**0.5 + else: + A = A / module.scaling + B = B / module.scaling + module.lora_A.set_value(A.astype(module.lora_A.dtype)) + module.lora_B.set_value(B.astype(module.lora_B.dtype)) + offset = module.lora_A @ module.lora_B + module.weight.data -= module.scaling * offset From 6c2da73c5991eb20dee3643f2cab0fc4fa89d3d6 Mon Sep 17 00:00:00 2001 From: greycooker 
<526929599@qq.com> Date: Mon, 21 Oct 2024 10:57:29 +0000 Subject: [PATCH 2/9] modify loraga_reinit --- llm/run_finetune.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/llm/run_finetune.py b/llm/run_finetune.py index b36763b4fbed..1e4b55014f46 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -109,6 +109,7 @@ def main(): if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1: try: from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401 + LinearConfig.enable_accumulate_steps_opt() LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps) except ImportError: @@ -480,11 +481,6 @@ def neft_post_hook(module, input, output): ) model.print_trainable_parameters() - if model_args.loraga: - from paddlenlp.utils.llm_utils import loraga_reinit - - loraga_reinit(model, named_grads, stable_gamma=model_args.loraga_stable_gamma) - if model_args.lora: if training_args.sharding_parallel_degree > 1: assert ( @@ -509,7 +505,10 @@ def neft_post_hook(module, input, output): model = LoRAModel(model, lora_config) else: model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path) + if model_args.loraga: + from paddlenlp.utils.llm_utils import loraga_reinit + loraga_reinit(model, named_grads, stable_gamma=model_args.loraga_stable_gamma) model.print_trainable_parameters() def compute_metrics_do_generation(eval_preds): From f6ee62298f6035da2728e67073fe006726312b79 Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Thu, 7 Nov 2024 08:15:45 +0000 Subject: [PATCH 3/9] Support multi GPU initialization --- llm/run_finetune.py | 6 +- llm/utils/argument.py | 1 - paddlenlp/peft/lora/lora_config.py | 1 - paddlenlp/trl/llm_utils.py | 242 ++++++++++++++++++++++++++--- 4 files changed, 219 insertions(+), 31 deletions(-) diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 5f075f05349f..33faca83c590 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -450,10 +450,8 @@ def neft_post_hook(module, input, output): model=model, train_ds=train_ds, data_collator=data_collator, - world_size=training_args.world_size, - batch_size=model_args.loraga_init_bsz, - iters=model_args.loraga_init_iters, - tokenizer=tokenizer, + training_args=training_args, + loraga_init_iters=model_args.loraga_init_iters, ) if model_args.prefix_tuning: diff --git a/llm/utils/argument.py b/llm/utils/argument.py index b089ee4e1247..38d5fc318208 100644 --- a/llm/utils/argument.py +++ b/llm/utils/argument.py @@ -210,7 +210,6 @@ class ModelArgument: lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"}) pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"}) loraga: bool = field(default=False, metadata={"help": "Whether to use LoRA-GA:https://arxiv.org/pdf/2407.05000"}) - loraga_init_bsz: int = field(default=1, metadata={"help": "The batch size for lora ga"}) loraga_init_iters: int = field(default=32, metadata={"help": "The number of init iterations for lora ga"}) loraga_stable_gamma: int = field(default=64, metadata={"help": "Lora Ga stable gamma"}) diff --git a/paddlenlp/peft/lora/lora_config.py b/paddlenlp/peft/lora/lora_config.py index 7785e2046c03..b1ab50e92e63 100644 --- a/paddlenlp/peft/lora/lora_config.py +++ b/paddlenlp/peft/lora/lora_config.py @@ -77,7 +77,6 @@ class LoRAConfig: rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"}) pissa: bool = field(default=False, metadata={"help": 
"Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"}) loraga: bool = field(default=False, metadata={"help": "Whether to use LoRA-GA:https://arxiv.org/pdf/2407.05000"}) - loraga_init_bsz: int = field(default=1, metadata={"help": "The batch size for lora ga"}) loraga_init_iters: int = field(default=4, metadata={"help": "The number of init iterations for lora ga"}) loraga_stable_gamma: int = field(default=32, metadata={"help": "Lora Ga stable gamma"}) lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+"}) diff --git a/paddlenlp/trl/llm_utils.py b/paddlenlp/trl/llm_utils.py index 23b03f9783c7..f10f75c2ccc4 100644 --- a/paddlenlp/trl/llm_utils.py +++ b/paddlenlp/trl/llm_utils.py @@ -25,12 +25,13 @@ import paddle.distributed.fleet.base.topology as tp import paddle.incubate.multiprocessing as mp from paddle.distributed import fleet +from paddle.io import DataLoader, DistributedBatchSampler from sklearn.metrics import accuracy_score from paddlenlp.datasets import ZeroPaddingIterableDataset from paddlenlp.generation import GenerationConfig from paddlenlp.trainer import TrainerCallback -from paddlenlp.trainer.trainer_utils import IterableDatasetShard +from paddlenlp.trainer.trainer_utils import IterableDatasetShard, ShardingOption from paddlenlp.transformers import ( AutoTokenizer, ChatGLMv2Tokenizer, @@ -38,6 +39,7 @@ PretrainedConfig, Qwen2ForCausalLMPipe, ) +from paddlenlp.transformers.model_utils import PretrainedModel from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer from paddlenlp.utils.log import logger @@ -744,41 +746,206 @@ def get_eos_token_id( return list(eos_token_ids_dict.values()) -def estimate_gradient(model, train_ds, data_collator, world_size=1, batch_size=1, iters=4, tokenizer=None): +def wrap_loraga_model(model, training_args): + sharding = None + if len(training_args.sharding) > 0: + if training_args.local_rank == -1: + raise ValueError("Using sharding only works in distributed training.") + sharding = True + + in_pipeline_parallel_mode = training_args.pipeline_parallel_degree > 1 + in_sharding_parallel_mode = sharding is not None + in_tensor_parallel_mode = training_args.tensor_parallel_degree > 1 + in_sep_parallel_mode = training_args.sep_parallel_degree > 1 + in_cp_parallel_mode = training_args.context_parallel_degree > 1 + + # Multi-gpu training + if training_args.world_size > 1 and (not training_args.use_hybrid_parallel): + # MOE use DDP to broadcaset parameters. + ddp_kwargs = {} + if training_args.ddp_find_unused_parameters is not None: + ddp_kwargs["find_unused_parameters"] = training_args.ddp_find_unused_parameters + elif isinstance(model, PretrainedModel): + # find_unused_parameters breaks checkpointing as per + # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + ddp_kwargs["find_unused_parameters"] = not any( + hasattr(m, "enable_recompute") and m.enable_recompute for m in model.sublayers(include_self=True) + ) + else: + ddp_kwargs["find_unused_parameters"] = True + model = paddle.DataParallel(model, **ddp_kwargs) + + # No pipeline mode, sharding only + if not in_pipeline_parallel_mode and in_sharding_parallel_mode: + # Sharded DDP! + if training_args.tensor_parallel_degree > 1: + hcg = fleet.get_hybrid_communicate_group() + assert ( + ShardingOption.SHARD_GRAD_OP in training_args.sharding + or ShardingOption.SHARD_OP in training_args.sharding + ), "Only support tensor parallel + sharding stage1/stage2 hybrid parallel now." 
+ model = paddle.distributed.fleet.meta_parallel.TensorParallel(model, hcg, strategy=None) + if ShardingOption.SHARD_OP in training_args.sharding: + model = fleet.distributed_model(model) + + if ( + not in_pipeline_parallel_mode + and not in_sharding_parallel_mode + and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode) + ): + model = fleet.distributed_model(model) + + return model + + +def get_loraga_dataloader(train_dataset, data_collator, training_args): + from paddlenlp.data import DistDataLoader + + def _is_iterable_dataset(dataset): + return isinstance(dataset, paddle.io.IterableDataset) + + def _is_iterable_dataset_distributed(dataset): + # For distributed dataloaer. + is_iterable_dataset_tensor = paddle.to_tensor(is_iterable_dataset(dataset)).astype("int32").reshape([1]) + if dist.get_world_size() > 1: + dist.all_reduce(is_iterable_dataset_tensor, op=dist.ReduceOp.MAX) + if is_iterable_dataset_tensor.item() == 1: + return True + return False + + if training_args.distributed_dataloader: + is_iterable_dataset = _is_iterable_dataset_distributed(train_dataset) + else: + is_iterable_dataset = _is_iterable_dataset(train_dataset) + + # if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset): + # train_dataset = self._remove_unused_columns(train_dataset, description="training") + _DataLoader = DistDataLoader if training_args.distributed_dataloader else DataLoader + + if is_iterable_dataset: # For iterable dataset + if training_args.dataset_world_size > 1 and train_dataset is not None: + train_dataset = IterableDatasetShard( + train_dataset, + batch_size=training_args.per_device_train_batch_size, + drop_last=training_args.dataloader_drop_last, + num_processes=training_args.dataset_world_size, + process_index=training_args.dataset_rank, + ) + + if training_args.distributed_dataloader: + logger.info("Training using DistDataLoader.") + additional_configs = {"is_iterable_dataset": True} + else: + additional_configs = {} + return _DataLoader( + train_dataset, + batch_size=training_args.per_device_train_batch_size, + collate_fn=data_collator, + num_workers=training_args.dataloader_num_workers, + **additional_configs, + ) + else: + train_sampler = get_loraga_train_sampler(train_dataset, training_args) + if training_args.distributed_dataloader: + logger.info("Training using DistDataLoader.") + return _DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=data_collator, + num_workers=training_args.dataloader_num_workers, + ) + + +def get_loraga_train_sampler(train_dataset, training_args) -> Optional[paddle.io.Sampler]: + if training_args.world_size <= 1: + return paddle.io.BatchSampler( + dataset=train_dataset, + shuffle=True, + batch_size=training_args.per_device_train_batch_size, + drop_last=training_args.dataloader_drop_last, + ) + + return DistributedBatchSampler( + train_dataset, + batch_size=training_args.per_device_train_batch_size, + shuffle=True, + num_replicas=training_args.dataset_world_size, + rank=training_args.dataset_rank, + drop_last=training_args.dataloader_drop_last, + ) + + +def estimate_gradient(model, train_ds, data_collator, training_args, loraga_init_iters=32): """Estimate the gradient of the model on the given dataset""" + import time + + start_time = time.time() logger.info("Estimating gradient for LoraGA") + split_mappings = model._get_tensor_parallel_mappings(config=model.config, is_split=False) + model = wrap_loraga_model(model, training_args) model.train() - named_grads = {} - 
dataloader = DataLoader(train_ds, collate_fn=data_collator, batch_size=batch_size, shuffle=False) - num_batch = 0 + gradient_dict = {} + logger.info(f"Initialization iterations for LoraGA: {loraga_init_iters}") + dataloader = get_loraga_dataloader(train_ds, data_collator, training_args) + iters = 0 for batch in dataloader: - num_batch += 1 + iters += 1 batch = {k: paddle.to_tensor(v) for k, v in batch.items()} + # Do not support pipeline parallel by now loss, logits = model(**batch) + # log_memory_usage() loss.backward() - + # log_memory_usage() # Record gradients for grad_name, param in model.named_parameters(): - # if param.stop_gradient is False: - if param.stop_gradient is False and param.grad is not None: - if grad_name not in named_grads: - named_grads[grad_name] = param.grad.clone() + # After tp/sharding wrapping, parameter names may be prefixed with one or more "_layers." segments; strip them here + grad_name = grad_name.split("_layers.")[-1] + if not param.stop_gradient and param.grad is not None: + if grad_name not in gradient_dict: + gradient_dict[grad_name] = param.grad.clone() else: - named_grads[grad_name] += param.grad + gradient_dict[grad_name] += param.grad + param.clear_gradient(False) # release gradient memory - param.clear_gradient() - if num_batch == iters: + if iters == loraga_init_iters: break - for grad_name, param in named_grads.items(): - named_grads[grad_name] /= num_batch - + for grad_name, param in gradient_dict.items(): + # Pipeline parallel is not supported yet! + # tp + if training_args.tensor_parallel_degree > 1: + if grad_name.split("gpt.")[-1] in split_mappings: + # Some models may not start with the "gpt." prefix? + merge_func = split_mappings[grad_name.split("gpt.")[-1]] + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + output_tensors = [] + dist.all_gather(output_tensors, gradient_dict[grad_name], group=model_parallel_group) + output_tensors = [t if len(t.shape) > 0 else t.reshape_([-1]) for t in output_tensors] + gradient_dict[grad_name] = paddle.to_tensor(merge_func(output_tensors)) + # sharding + if training_args.sharding_parallel_degree > 1: + hcg = fleet.get_hybrid_communicate_group() + sharding_parallel_group = hcg.get_sharding_parallel_group() + if sharding_parallel_group.nranks > 1: + dist.all_reduce(gradient_dict[grad_name], op=dist.ReduceOp.SUM, group=sharding_parallel_group) + gradient_dict[grad_name] /= sharding_parallel_group.nranks + # dp + if training_args.data_parallel_degree > 1: + hcg = fleet.get_hybrid_communicate_group() + data_parallel_group = hcg.get_data_parallel_group() + if data_parallel_group.nranks > 1: + dist.all_reduce(gradient_dict[grad_name], op=dist.ReduceOp.SUM, group=data_parallel_group) + gradient_dict[grad_name] /= data_parallel_group.nranks + gradient_dict[grad_name] /= loraga_init_iters paddle.device.cuda.empty_cache() - return named_grads + + logger.info("Gradient Approximation execution time: {} seconds".format(time.time() - start_time)) + return gradient_dict -def loraga_reinit(model, named_grads, stable_gamma, **kwargs): +def loraga_reinit(model, gradient_dict, stable_gamma, training_args, **kwargs): """Re-initialize the weights of the model using the estimated gradients""" from tqdm import tqdm @@ -787,19 +954,37 @@ def loraga_reinit(model, named_grads, stable_gamma, **kwargs): desc="Reinitializing Lora", total=len(list(model.named_sublayers())), ): - from paddlenlp.peft.lora.lora_layers import LoRALinear + from paddlenlp.peft.lora.lora_layers import ( + ColumnParallelLoRALinear, + ColumnSequenceParallelLoRALinear, + LoRALinear, + RowParallelLoRALinear, + RowSequenceParallelLoRALinear, + ) - if
isinstance(module, LoRALinear): - loraga_reinit_modules(name, module, named_grads, stable_gamma, **kwargs) + lora_split_mapping = None + if ( + isinstance(module, LoRALinear) + or isinstance(module, RowSequenceParallelLoRALinear) + or isinstance(module, ColumnSequenceParallelLoRALinear) + or isinstance(module, RowParallelLoRALinear) + or isinstance(module, ColumnParallelLoRALinear) + ): + is_tp = training_args.tensor_parallel_degree > 1 + if is_tp: + lora_split_mapping = model._get_tensor_parallel_mappings(model.config) + loraga_reinit_modules(name, module, gradient_dict, stable_gamma, is_tp, lora_split_mapping, **kwargs) -def loraga_reinit_modules(name, module, named_grads, stable_gamma, **kwargs): +def loraga_reinit_modules(name, module, gradient_dict, stable_gamma, is_tp=False, lora_split_mapping=None, **kwargs): with paddle.no_grad(): lora_r = module.r grad_name = ".".join(name.split(".")[1:]) + ".weight" - grads = named_grads[grad_name] + loraA_name = ".".join(name.split(".")[1:]) + ".lora_A" + loraB_name = ".".join(name.split(".")[1:]) + ".lora_B" + grads = gradient_dict[grad_name] - U, S, V = paddle.linalg.svd_lowrank(grads.cuda().astype("float32"), q=4 * lora_r, niter=4) + U, S, V = paddle.linalg.svd_lowrank(grads.astype("float32"), q=4 * lora_r, niter=4) V = V.T A = U[:, lora_r : 2 * lora_r] @@ -812,6 +997,13 @@ def loraga_reinit_modules(name, module, named_grads, stable_gamma, **kwargs): else: A = A / module.scaling B = B / module.scaling + if is_tp: + if module.lora_A.is_distributed and lora_split_mapping: + split_function = lora_split_mapping[loraA_name] + A = paddle.to_tensor(split_function(A)) + if module.lora_B.is_distributed and lora_split_mapping: + split_function = lora_split_mapping[loraB_name] + B = paddle.to_tensor(split_function(B)) module.lora_A.set_value(A.astype(module.lora_A.dtype)) module.lora_B.set_value(B.astype(module.lora_B.dtype)) offset = module.lora_A @ module.lora_B From 6a8d175ceecaec09a0fc7e8c23a62071f3d4f244 Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Mon, 9 Dec 2024 13:29:46 +0000 Subject: [PATCH 4/9] support resume training and gradient offlaod hook --- paddlenlp/peft/lora/lora_config.py | 4 +- paddlenlp/peft/lora/lora_model.py | 54 ++- paddlenlp/peft/lora/loraga_utils.py | 416 ++++++++++++++++++ .../load_save_single_card.py | 2 +- .../unified_checkpoint/unified_checkpoint.py | 2 +- paddlenlp/trainer/unified_checkpoint/utils.py | 6 +- 6 files changed, 472 insertions(+), 12 deletions(-) create mode 100644 paddlenlp/peft/lora/loraga_utils.py diff --git a/paddlenlp/peft/lora/lora_config.py b/paddlenlp/peft/lora/lora_config.py index a47a3741a2b1..3f938e93a188 100644 --- a/paddlenlp/peft/lora/lora_config.py +++ b/paddlenlp/peft/lora/lora_config.py @@ -76,9 +76,7 @@ class LoRAConfig: do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"}) rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"}) pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"}) - loraga: bool = field(default=False, metadata={"help": "Whether to use LoRA-GA:https://arxiv.org/pdf/2407.05000"}) - loraga_init_iters: int = field(default=4, metadata={"help": "The number of init iterations for lora ga"}) - loraga_stable_gamma: int = field(default=32, metadata={"help": "Lora Ga stable gamma"}) + loraga: bool = field(default=False, metadata={"help": "Whether to LoRA-GA"}) lora_plus_scale: float = field(default=1.0, metadata={"help": 
"Lora B scale in LoRA+"}) base_model_name_or_path: Optional[str] = field( default=None, metadata={"help": "The name of the base model to use."} diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index 3f0453b7bc35..f52e662918eb 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -163,6 +163,8 @@ def __init__(self, model, lora_config: LoRAConfig) -> None: ) self.forward = self.model.forward + if lora_config.loraga: + self.loraga_init_dict = {} logger.info("Mark only lora and trainable_module as trainable.") self.mark_only_lora_as_trainable() @@ -269,7 +271,7 @@ def from_pretrained(cls, model, lora_path, **kwargs): tp_actions if pre_tensor_parallel_split else None, expected_keys, ) - error_msgs += _load_state_dict_into_model(lora_model.model, state_dict, "") + error_msgs += _load_state_dict_into_model(lora_model, state_dict, "") del state_dict gc.collect() @@ -319,6 +321,41 @@ def set_state_dict(self, state_dict): warnings.filterwarnings( action="ignore", message=".*Skip loading for.*", category=Warning, lineno=0, append=False ) + + model_state_dict = self.model.state_dict() + if self.lora_config.loraga: + + def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict): + if isinstance(concat_tensor, np.ndarray): + final_lora, init_lora = np.split(concat_tensor, 2, axis=axis) + init_lora = paddle.to_tensor(init_lora) + else: + final_lora, init_lora = paddle.split(concat_tensor, 2, axis=axis) + init_dict[name] = init_lora + state_dict[name] = final_lora + return final_lora, init_lora + + for name in state_dict.keys(): + if "lora_A" in name: + concat_lora_A = state_dict[name] + final_loraA, init_loraA = process_split_and_assign( + name, concat_lora_A, axis=1, init_dict=self.loraga_init_dict, state_dict=state_dict + ) + + loraB_name = name.replace("lora_A", "lora_B") + concat_lora_B = state_dict[loraB_name] + final_loraB, init_loraB = process_split_and_assign( + loraB_name, concat_lora_B, axis=0, init_dict=self.loraga_init_dict, state_dict=state_dict + ) + + base_name = name.replace("lora_A", "weight") + + # Reinit base model + offset = init_loraA.cuda() @ init_loraB.cuda() + ori_weight = model_state_dict[base_name] + model_state_dict[base_name].set_value(ori_weight - self.lora_config.scaling * offset) + del model_state_dict + gc.collect() self.model.set_state_dict(state_dict) logger.info("Load lora weight successfully") @@ -400,8 +437,9 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal lora_config_to_save = LoRAConfig(**self.lora_config.to_dict()) + trainable_state_dict = self.get_trainable_state_dict(concat_init_lora=lora_config_to_save.loraga) + if merge_tensor_parallel and lora_config_to_save.tensor_parallel_degree > 1: - trainable_state_dict = self.get_trainable_state_dict() trainable_state_dict = self._merge_trainable_tensor_parallel(trainable_state_dict) if not is_main_process: logger.info("Saving with merge_tensor_parallel, tensor_parallel_rank > 0 don't need save") @@ -410,7 +448,6 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal variant = "_".join([x for x in variant.split("_") if "tp" not in x]) lora_config_to_save.tensor_parallel_degree = -1 else: - trainable_state_dict = self.get_trainable_state_dict() if lora_config_to_save.tensor_parallel_degree > 1: if variant is None: variant = weight_name_suffix() @@ -641,12 +678,19 @@ def _find_and_restore_module(self, module_name): original_module.bias = module.bias setattr(parent_module, 
attribute_chain[-1], original_module) - def get_trainable_state_dict(self): + def get_trainable_state_dict(self, concat_init_lora=False): trainable_state_dict = OrderedDict() for name, weight in self.model.state_dict().items(): # get lora parameter & QAT scale parameter if not weight.stop_gradient or "activation_quanter" in name or "weight_quanter" in name: - trainable_state_dict[name] = weight + if concat_init_lora: + if "lora_A" in name: + trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=1) + else: + trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=0) + else: + trainable_state_dict[name] = weight + return trainable_state_dict def print_trainable_parameters(self) -> None: diff --git a/paddlenlp/peft/lora/loraga_utils.py b/paddlenlp/peft/lora/loraga_utils.py new file mode 100644 index 000000000000..b9f5f6871aee --- /dev/null +++ b/paddlenlp/peft/lora/loraga_utils.py @@ -0,0 +1,416 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet +from paddle.io import DataLoader, DistributedBatchSampler + +from paddlenlp.peft.lora.lora_layers import ( + ColumnParallelLoRALinear, + ColumnSequenceParallelLoRALinear, + LoRALinear, + RowParallelLoRALinear, + RowSequenceParallelLoRALinear, +) +from paddlenlp.trainer.trainer_utils import IterableDatasetShard +from paddlenlp.transformers.model_utils import unwrap_model +from paddlenlp.utils.log import logger + + +def wrap_loraga_model(model, training_args): + """Wrap Model with distributed strategies, support tp, dp, sharding""" + + from paddlenlp.trainer.trainer_utils import ShardingOption + from paddlenlp.transformers.model_utils import PretrainedModel + + sharding = None + if len(training_args.sharding) > 0: + if training_args.local_rank == -1: + raise ValueError("Using sharding only works in distributed training.") + sharding = True + + in_pipeline_parallel_mode = training_args.pipeline_parallel_degree > 1 + in_sharding_parallel_mode = sharding is not None + in_tensor_parallel_mode = training_args.tensor_parallel_degree > 1 + in_sep_parallel_mode = training_args.sep_parallel_degree > 1 + in_cp_parallel_mode = training_args.context_parallel_degree > 1 + + # Multi-gpu training + if training_args.world_size > 1 and (not training_args.use_hybrid_parallel): + # MOE use DDP to broadcaset parameters. 
+ ddp_kwargs = {} + if training_args.ddp_find_unused_parameters is not None: + ddp_kwargs["find_unused_parameters"] = training_args.ddp_find_unused_parameters + elif isinstance(model, PretrainedModel): + # find_unused_parameters breaks checkpointing as per + # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + ddp_kwargs["find_unused_parameters"] = not any( + hasattr(m, "enable_recompute") and m.enable_recompute for m in model.sublayers(include_self=True) + ) + else: + ddp_kwargs["find_unused_parameters"] = True + model = paddle.DataParallel(model, **ddp_kwargs) + + # No pipeline mode, sharding only + if not in_pipeline_parallel_mode and in_sharding_parallel_mode: + # Sharded DDP! + if training_args.tensor_parallel_degree > 1: + hcg = fleet.get_hybrid_communicate_group() + assert ( + ShardingOption.SHARD_GRAD_OP in training_args.sharding + or ShardingOption.SHARD_OP in training_args.sharding + ), "Only support tensor parallel + sharding stage1/stage2 hybrid parallel now." + model = paddle.distributed.fleet.meta_parallel.TensorParallel(model, hcg, strategy=None) + if ShardingOption.SHARD_OP in training_args.sharding: + model = fleet.distributed_model(model) + + if ( + not in_pipeline_parallel_mode + and not in_sharding_parallel_mode + and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode) + ): + model = fleet.distributed_model(model) + + return model + + +def get_loraga_dataloader(train_dataset, data_collator, training_args): + from paddlenlp.data import DistDataLoader + + def is_iterable_dataset(dataset): + return isinstance(dataset, paddle.io.IterableDataset) + + def is_iterable_dataset_distributed(dataset): + # For distributed dataloaer. + is_iterable_dataset_tensor = paddle.to_tensor(is_iterable_dataset(dataset)).astype("int32").reshape([1]) + if dist.get_world_size() > 1: + dist.all_reduce(is_iterable_dataset_tensor, op=dist.ReduceOp.MAX) + if is_iterable_dataset_tensor.item() == 1: + return True + return False + + if training_args.distributed_dataloader: + iterable_dataset = is_iterable_dataset_distributed(train_dataset) + else: + iterable_dataset = is_iterable_dataset(train_dataset) + + # if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset): + # train_dataset = self._remove_unused_columns(train_dataset, description="training") + _DataLoader = DistDataLoader if training_args.distributed_dataloader else DataLoader + + if iterable_dataset: # For iterable dataset + if training_args.dataset_world_size > 1 and train_dataset is not None: + train_dataset = IterableDatasetShard( + train_dataset, + batch_size=training_args.per_device_train_batch_size, + drop_last=training_args.dataloader_drop_last, + num_processes=training_args.dataset_world_size, + process_index=training_args.dataset_rank, + ) + + if training_args.distributed_dataloader: + logger.info("Training using DistDataLoader.") + additional_configs = {"is_iterable_dataset": True} + else: + additional_configs = {} + return _DataLoader( + train_dataset, + batch_size=training_args.per_device_train_batch_size, + collate_fn=data_collator, + num_workers=training_args.dataloader_num_workers, + **additional_configs, + ) + else: + train_sampler = get_loraga_train_sampler(train_dataset, training_args) + if training_args.distributed_dataloader: + logger.info("Training using DistDataLoader.") + return _DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=data_collator, + num_workers=training_args.dataloader_num_workers, + ) + + 
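
For reference, the gradient estimation these dataloader helpers feed (the estimate_gradient function later in this file) is conceptually a short loop that never updates the weights: run a handful of forward/backward passes and average the per-parameter gradients. A minimal single-device sketch of that loop; the helper name is illustrative, and it assumes, as in the fine-tuning setup above, that the model's forward returns a (loss, logits) pair:

import paddle

def estimate_named_gradients(model, dataloader, num_iters=32):
    """Average per-parameter gradients over a few mini-batches on a single device."""
    model.train()
    grad_sums, step = {}, 0
    for batch in dataloader:
        step += 1
        batch = {k: paddle.to_tensor(v) for k, v in batch.items()}
        loss, _ = model(**batch)  # forward is assumed to return (loss, logits)
        loss.backward()
        for name, param in model.named_parameters():
            if not param.stop_gradient and param.grad is not None:
                if name not in grad_sums:
                    grad_sums[name] = param.grad.clone()
                else:
                    grad_sums[name] += param.grad
                param.clear_gradient(False)  # free the per-step gradient memory
        if step == num_iters:
            break
    return {name: grad / step for name, grad in grad_sums.items()}

In the patch itself this loop is additionally wrapped with the distributed model and, from this patch on, with the gradient-offload backward hooks, so the averaged gradients can be gathered across tensor-, sharding- and data-parallel groups before the SVD step.
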
+def get_loraga_train_sampler(train_dataset, training_args) -> Optional[paddle.io.Sampler]: + if training_args.world_size <= 1: + return paddle.io.BatchSampler( + dataset=train_dataset, + shuffle=True, + batch_size=training_args.per_device_train_batch_size, + drop_last=training_args.dataloader_drop_last, + ) + + return DistributedBatchSampler( + train_dataset, + batch_size=training_args.per_device_train_batch_size, + shuffle=True, + num_replicas=training_args.dataset_world_size, + rank=training_args.dataset_rank, + drop_last=training_args.dataloader_drop_last, + ) + + +def estimate_gradient(model, train_ds, data_collator, training_args, loraga_init_iters=32, gradient_offload=False): + """Estimate the gradient of the model on the given dataset""" + gradient_dict = {} + logger.info("Estimating gradient for LoraGA.") + + model = wrap_loraga_model(model, training_args) + model.train() + + logger.info(f"Initialization iterations for LoraGA: {loraga_init_iters}") + dataloader = get_loraga_dataloader(train_ds, data_collator, training_args) + iters = 0 + + with GradientOffloadHookContext( + model=model, + gradient_dict=gradient_dict, + local_rank=training_args.local_rank, + loraga_init_iters=loraga_init_iters, + gradient_offload=gradient_offload, + ): + for batch in dataloader: + iters += 1 + batch = {k: paddle.to_tensor(v) for k, v in batch.items()} + + # Pipeline parallel not supported currently + loss, logits = model(**batch) + loss.backward() + + if iters == loraga_init_iters: + break + + return gradient_dict + + +def get_module_gradient( + grad_name, + base_model_prefix, + gradient_dict, + base_model_split_mappings, + tp_degree, + sharding_degree, + dp_degree, + local_rank, +): + rank_suffix = "_" + str(local_rank) + local_grad_name = ".".join(grad_name.split(".")[1:]) + ".weight" + rank_suffix + gradient = gradient_dict.pop(local_grad_name).cuda() + if tp_degree > 1: + # remove prefix and suffix + model_split_key = local_grad_name.split(base_model_prefix)[-1].rsplit(rank_suffix, 1)[0] + if model_split_key in base_model_split_mappings: + merge_func = base_model_split_mappings[model_split_key] + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + output_tensors = [] + dist.all_gather(output_tensors, gradient, group=model_parallel_group) + + output_tensors = [t if len(t.shape) > 0 else t.reshape_([-1]) for t in output_tensors] + gradient = merge_func(output_tensors).cuda() + + # sharding + if sharding_degree > 1: + hcg = fleet.get_hybrid_communicate_group() + sharding_parallel_group = hcg.get_sharding_parallel_group() + if sharding_parallel_group.nranks > 1: + + dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=sharding_parallel_group) + gradient /= sharding_parallel_group.nranks + # dp + if dp_degree > 1: + hcg = fleet.get_hybrid_communicate_group() + data_parallel_group = hcg.get_data_parallel_group() + if data_parallel_group.nranks > 1: + dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=data_parallel_group) + gradient /= data_parallel_group.nranks + return gradient + + +def loraga_svd_reinit(model, gradient_dict, base_model_split_mappings, stable_gamma, training_args, **kwargs) -> None: + """ + If Loraga has already been initialized, directly modify the base model weights. + Otherwise, reinitialize and save the initialized model. + + Args: + model (Any): The model to reinitialize. + gradient_dict (Dict[str, Any]): Dictionary containing gradients. + model_split_mappings (Any): Mappings for model tensor parallelism. 
+ stable_gamma (Any): Stable gamma parameter for Loraga. + training_args (Any): Training arguments. + **kwargs: Additional keyword arguments. + """ + + lora_split_mapping = None + tensor_parallel_degree = training_args.tensor_parallel_degree + in_tensor_parallel_mode = tensor_parallel_degree > 1 + + base_model_prefix = unwrap_model(model).base_model_prefix + "." + if in_tensor_parallel_mode: + lora_split_mapping = model._get_tensor_parallel_mappings(model.config) + loraga_init_dict = {} + for name, module in model.named_sublayers(): + if isinstance( + module, + ( + LoRALinear, + RowSequenceParallelLoRALinear, + ColumnSequenceParallelLoRALinear, + RowParallelLoRALinear, + ColumnParallelLoRALinear, + ), + ): + # gather gradient if in tensor parallel mode, average gradient if in data parallel mode + module_gradient = get_module_gradient( + name, + base_model_prefix, + gradient_dict, + base_model_split_mappings, + training_args.tensor_parallel_degree, + training_args.sharding_parallel_degree, + training_args.data_parallel_degree, + training_args.local_rank, + ) + # perform SVD to reinit base model weight and lora adapter weight + loraga_svd_module( + name, + module, + module_gradient, + stable_gamma, + loraga_init_dict, + in_tensor_parallel_mode, + lora_split_mapping, + **kwargs, + ) + model.loraga_init_dict = loraga_init_dict + + +def loraga_svd_module( + name, + module, + grads, + stable_gamma, + loraga_init_dict, + in_tensor_parallel_mode=False, + lora_split_mapping=None, + **kwargs +): + with paddle.no_grad(): + lora_r = module.r + + loraA_name = ".".join(name.split(".")[1:]) + ".lora_A" + loraB_name = ".".join(name.split(".")[1:]) + ".lora_B" + + U, S, V = paddle.linalg.svd_lowrank(grads.astype("float32"), q=4 * lora_r, niter=4) + + V = V.T + # get new low rank adapter after SVD + A = U[:, lora_r : 2 * lora_r] + B = V[:lora_r, :] + m, n = grads.shape # m: feature_out, n: feature_in + # If stable_gamma is not -1, scale the matrices A and B by the square root of the stable_gamma + if stable_gamma != -1: + A = A * m**0.25 / stable_gamma**0.5 + B = B * m**0.25 / stable_gamma**0.5 + else: + A = A / module.scaling + B = B / module.scaling + + if in_tensor_parallel_mode: + # split lora adapter weight if in tensor parallel mode + if module.lora_A.is_distributed and lora_split_mapping is not None: + split_function = lora_split_mapping[loraA_name] + A = paddle.to_tensor(split_function(A)) + if module.lora_B.is_distributed and lora_split_mapping is not None: + split_function = lora_split_mapping[loraB_name] + B = paddle.to_tensor(split_function(B)) + A = A.astype(module.lora_A.dtype) + B = B.astype(module.lora_B.dtype) + loraga_init_dict[loraA_name] = A + loraga_init_dict[loraB_name] = B + # reinit lora adapter weight + module.lora_A.set_value(A) + module.lora_B.set_value(B) + + offset = module.lora_A @ module.lora_B + # reinit base model weight + module.weight.data -= module.scaling * offset + + +def set_hook_enable(value=False): + global ENABLE_HOOK + ENABLE_HOOK = value + + +def get_hook_enable(): + global ENABLE_HOOK + return ENABLE_HOOK + + +class GradientOffloadHookContext: + def __init__( + self, + model, + gradient_dict: dict, + local_rank: int = 0, + loraga_init_iters: int = 4, + gradient_offload: bool = False, + *args, + **kwargs, + ): + """Offload gradient to cpu""" + self.model = model + self.gradient_dict = gradient_dict + self.local_rank = local_rank + self.loraga_init_iters = loraga_init_iters + self.gradient_offload = gradient_offload + + def __enter__(self): + set_hook_enable(True) 
+ self.register_gradient_hook() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + set_hook_enable(False) + + def register_gradient_hook(self): + for grad_name, param in self.model.named_parameters(): + param._register_backward_hook( + self.get_record_gradient_hook(self.model, self.gradient_dict, grad_name, param) + ) + + def get_record_gradient_hook(self, model, gradient_dict, grad_name, param): + def record_gradient_hook(*_): + if get_hook_enable(): + grad = param.grad + local_grad_name = grad_name.split("_layers.")[-1] + "_" + str(self.local_rank) + if not param.stop_gradient and grad is not None: + if local_grad_name not in gradient_dict: + if self.gradient_offload: + gradient_dict[local_grad_name] = (grad / self.loraga_init_iters).cpu() + else: + gradient_dict[local_grad_name] = grad.clone() / self.loraga_init_iters + else: + if self.gradient_offload: + new_grad = gradient_dict[local_grad_name].cuda() + grad / self.loraga_init_iters + gradient_dict[local_grad_name] = new_grad.cpu() + else: + gradient_dict[local_grad_name] += grad / self.loraga_init_iters + param.clear_gradient(False) # release gradient memory + + return record_gradient_hook diff --git a/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py b/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py index 30287981d438..47689939617b 100644 --- a/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py +++ b/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py @@ -67,7 +67,7 @@ def save_file_sync(state_dict, path): def save_single_card_checkpoint(model_to_save, output_dir): """Save checkpoint for non-distributed environment.""" - state_dict = get_expected_state_dict(model_to_save) + state_dict = get_expected_state_dict(model_to_save, concat_additional_adapter=True) if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): weight_filename = "peft_model-00001-of-00001.safetensors" index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index f8875cc89262..37bb8facea05 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -507,7 +507,7 @@ def unified_checkpoint_into_shards( paddle.device.cuda.empty_cache() assert hasattr(model_to_save, "config") - state_dict = get_expected_state_dict(model_to_save) + state_dict = get_expected_state_dict(model_to_save, concat_additional_adapter=True) all_filter_keys = filter_params(model_to_save, state_dict, args) config_to_save = copy.deepcopy(model_to_save.config) diff --git a/paddlenlp/trainer/unified_checkpoint/utils.py b/paddlenlp/trainer/unified_checkpoint/utils.py index 58e425ca987d..6cc57c148a08 100644 --- a/paddlenlp/trainer/unified_checkpoint/utils.py +++ b/paddlenlp/trainer/unified_checkpoint/utils.py @@ -202,7 +202,7 @@ def mapping_optimizer_tp_actions(tp_actions, optimizer_loaded_keys): return new_actions -def get_expected_state_dict(model_to_save): +def get_expected_state_dict(model_to_save, **kwargs): """ Get trainable state_dict of model_to_save. 
""" @@ -220,7 +220,9 @@ def get_expected_state_dict(model_to_save): if key in state_dict: state_dict.pop(key) elif isinstance(model_to_save, LoRAModel): - state_dict = model_to_save.get_trainable_state_dict() + concat_additional_adapter = kwargs.get("concat_additional_adapter", False) + concat_init_lora = model_to_save.lora_config.loraga and concat_additional_adapter + state_dict = model_to_save.get_trainable_state_dict(concat_init_lora=concat_init_lora) elif isinstance(model_to_save, PrefixModelForCausalLM): state_dict = model_to_save.prefix_encoder.state_dict() From 0421446e41143b3c9b2e63c55ba4aa2486eaa260 Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Mon, 9 Dec 2024 13:40:16 +0000 Subject: [PATCH 5/9] remove trl/llm_utils.py --- paddlenlp/trl/llm_utils.py | 268 +------------------------------------ 1 file changed, 1 insertion(+), 267 deletions(-) diff --git a/paddlenlp/trl/llm_utils.py b/paddlenlp/trl/llm_utils.py index 7b35df0d12ee..ee92735147af 100644 --- a/paddlenlp/trl/llm_utils.py +++ b/paddlenlp/trl/llm_utils.py @@ -25,13 +25,12 @@ import paddle.distributed.fleet.base.topology as tp import paddle.incubate.multiprocessing as mp from paddle.distributed import fleet -from paddle.io import DataLoader, DistributedBatchSampler from sklearn.metrics import accuracy_score from paddlenlp.datasets import ZeroPaddingIterableDataset from paddlenlp.generation import GenerationConfig from paddlenlp.trainer import TrainerCallback -from paddlenlp.trainer.trainer_utils import IterableDatasetShard, ShardingOption +from paddlenlp.trainer.trainer_utils import IterableDatasetShard from paddlenlp.transformers import ( AutoTokenizer, ChatGLMv2Tokenizer, @@ -39,7 +38,6 @@ PretrainedConfig, Qwen2ForCausalLMPipe, ) -from paddlenlp.transformers.model_utils import PretrainedModel from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer from paddlenlp.utils.log import logger @@ -760,267 +758,3 @@ def get_eos_token_id( eos_token_ids_dict = {str(item): item for item in eos_token_ids} return list(eos_token_ids_dict.values()) - - -def wrap_loraga_model(model, training_args): - sharding = None - if len(training_args.sharding) > 0: - if training_args.local_rank == -1: - raise ValueError("Using sharding only works in distributed training.") - sharding = True - - in_pipeline_parallel_mode = training_args.pipeline_parallel_degree > 1 - in_sharding_parallel_mode = sharding is not None - in_tensor_parallel_mode = training_args.tensor_parallel_degree > 1 - in_sep_parallel_mode = training_args.sep_parallel_degree > 1 - in_cp_parallel_mode = training_args.context_parallel_degree > 1 - - # Multi-gpu training - if training_args.world_size > 1 and (not training_args.use_hybrid_parallel): - # MOE use DDP to broadcaset parameters. - ddp_kwargs = {} - if training_args.ddp_find_unused_parameters is not None: - ddp_kwargs["find_unused_parameters"] = training_args.ddp_find_unused_parameters - elif isinstance(model, PretrainedModel): - # find_unused_parameters breaks checkpointing as per - # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 - ddp_kwargs["find_unused_parameters"] = not any( - hasattr(m, "enable_recompute") and m.enable_recompute for m in model.sublayers(include_self=True) - ) - else: - ddp_kwargs["find_unused_parameters"] = True - model = paddle.DataParallel(model, **ddp_kwargs) - - # No pipeline mode, sharding only - if not in_pipeline_parallel_mode and in_sharding_parallel_mode: - # Sharded DDP! 
- if training_args.tensor_parallel_degree > 1: - hcg = fleet.get_hybrid_communicate_group() - assert ( - ShardingOption.SHARD_GRAD_OP in training_args.sharding - or ShardingOption.SHARD_OP in training_args.sharding - ), "Only support tensor parallel + sharding stage1/stage2 hybrid parallel now." - model = paddle.distributed.fleet.meta_parallel.TensorParallel(model, hcg, strategy=None) - if ShardingOption.SHARD_OP in training_args.sharding: - model = fleet.distributed_model(model) - - if ( - not in_pipeline_parallel_mode - and not in_sharding_parallel_mode - and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode) - ): - model = fleet.distributed_model(model) - - return model - - -def get_loraga_dataloader(train_dataset, data_collator, training_args): - from paddlenlp.data import DistDataLoader - - def _is_iterable_dataset(dataset): - return isinstance(dataset, paddle.io.IterableDataset) - - def _is_iterable_dataset_distributed(dataset): - # For distributed dataloaer. - is_iterable_dataset_tensor = paddle.to_tensor(is_iterable_dataset(dataset)).astype("int32").reshape([1]) - if dist.get_world_size() > 1: - dist.all_reduce(is_iterable_dataset_tensor, op=dist.ReduceOp.MAX) - if is_iterable_dataset_tensor.item() == 1: - return True - return False - - if training_args.distributed_dataloader: - is_iterable_dataset = _is_iterable_dataset_distributed(train_dataset) - else: - is_iterable_dataset = _is_iterable_dataset(train_dataset) - - # if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset): - # train_dataset = self._remove_unused_columns(train_dataset, description="training") - _DataLoader = DistDataLoader if training_args.distributed_dataloader else DataLoader - - if is_iterable_dataset: # For iterable dataset - if training_args.dataset_world_size > 1 and train_dataset is not None: - train_dataset = IterableDatasetShard( - train_dataset, - batch_size=training_args.per_device_train_batch_size, - drop_last=training_args.dataloader_drop_last, - num_processes=training_args.dataset_world_size, - process_index=training_args.dataset_rank, - ) - - if training_args.distributed_dataloader: - logger.info("Training using DistDataLoader.") - additional_configs = {"is_iterable_dataset": True} - else: - additional_configs = {} - return _DataLoader( - train_dataset, - batch_size=training_args.per_device_train_batch_size, - collate_fn=data_collator, - num_workers=training_args.dataloader_num_workers, - **additional_configs, - ) - else: - train_sampler = get_loraga_train_sampler(train_dataset, training_args) - if training_args.distributed_dataloader: - logger.info("Training using DistDataLoader.") - return _DataLoader( - train_dataset, - batch_sampler=train_sampler, - collate_fn=data_collator, - num_workers=training_args.dataloader_num_workers, - ) - - -def get_loraga_train_sampler(train_dataset, training_args) -> Optional[paddle.io.Sampler]: - if training_args.world_size <= 1: - return paddle.io.BatchSampler( - dataset=train_dataset, - shuffle=True, - batch_size=training_args.per_device_train_batch_size, - drop_last=training_args.dataloader_drop_last, - ) - - return DistributedBatchSampler( - train_dataset, - batch_size=training_args.per_device_train_batch_size, - shuffle=True, - num_replicas=training_args.dataset_world_size, - rank=training_args.dataset_rank, - drop_last=training_args.dataloader_drop_last, - ) - - -def estimate_gradient(model, train_ds, data_collator, training_args, loraga_init_iters=32): - """Estimate the gradient of the 
model on the given dataset""" - - import time - - start_time = time.time() - logger.info("Estimating gradient for LoraGA") - split_mappings = model._get_tensor_parallel_mappings(config=model.config, is_split=False) - model = wrap_loraga_model(model, training_args) - model.train() - gradient_dict = {} - logger.info(f"Initialization iterations for LoraGA: {loraga_init_iters}") - dataloader = get_loraga_dataloader(train_ds, data_collator, training_args) - iters = 0 - for batch in dataloader: - iters += 1 - batch = {k: paddle.to_tensor(v) for k, v in batch.items()} - # Do not support pipeline parallel by now - loss, logits = model(**batch) - # log_memory_usage() - loss.backward() - # log_memory_usage() - # Record gradients - for grad_name, param in model.named_parameters(): - # After tp/sharding wrapping, parameter names may be prefixed with one or more "_layers." segments; strip them here - grad_name = grad_name.split("_layers.")[-1] - if not param.stop_gradient and param.grad is not None: - if grad_name not in gradient_dict: - gradient_dict[grad_name] = param.grad.clone() - else: - gradient_dict[grad_name] += param.grad - param.clear_gradient(False) # release gradient memory - - if iters == loraga_init_iters: - break - - for grad_name, param in gradient_dict.items(): - # Pipeline parallel is not supported yet! - # tp - if training_args.tensor_parallel_degree > 1: - if grad_name.split("gpt.")[-1] in split_mappings: - # Some models may not start with the "gpt." prefix? - merge_func = split_mappings[grad_name.split("gpt.")[-1]] - hcg = fleet.get_hybrid_communicate_group() - model_parallel_group = hcg.get_model_parallel_group() - output_tensors = [] - dist.all_gather(output_tensors, gradient_dict[grad_name], group=model_parallel_group) - output_tensors = [t if len(t.shape) > 0 else t.reshape_([-1]) for t in output_tensors] - gradient_dict[grad_name] = paddle.to_tensor(merge_func(output_tensors)) - # sharding - if training_args.sharding_parallel_degree > 1: - hcg = fleet.get_hybrid_communicate_group() - sharding_parallel_group = hcg.get_sharding_parallel_group() - if sharding_parallel_group.nranks > 1: - dist.all_reduce(gradient_dict[grad_name], op=dist.ReduceOp.SUM, group=sharding_parallel_group) - gradient_dict[grad_name] /= sharding_parallel_group.nranks - # dp - if training_args.data_parallel_degree > 1: - hcg = fleet.get_hybrid_communicate_group() - data_parallel_group = hcg.get_data_parallel_group() - if data_parallel_group.nranks > 1: - dist.all_reduce(gradient_dict[grad_name], op=dist.ReduceOp.SUM, group=data_parallel_group) - gradient_dict[grad_name] /= data_parallel_group.nranks - gradient_dict[grad_name] /= loraga_init_iters - paddle.device.cuda.empty_cache() - - logger.info("Gradient Approximation execution time: {} seconds".format(time.time() - start_time)) - return gradient_dict - - -def loraga_reinit(model, gradient_dict, stable_gamma, training_args, **kwargs): - """Re-initialize the weights of the model using the estimated gradients""" - from tqdm import tqdm - - for name, module in tqdm( - model.named_sublayers(), - desc="Reinitializing Lora", - total=len(list(model.named_sublayers())), - ): - from paddlenlp.peft.lora.lora_layers import ( - ColumnParallelLoRALinear, - ColumnSequenceParallelLoRALinear, - LoRALinear, - RowParallelLoRALinear, - RowSequenceParallelLoRALinear, - ) - - lora_split_mapping = None - if ( - isinstance(module, LoRALinear) - or isinstance(module, RowSequenceParallelLoRALinear) - or isinstance(module, ColumnSequenceParallelLoRALinear) - or isinstance(module, RowParallelLoRALinear) - or isinstance(module, ColumnParallelLoRALinear) - ): - is_tp = training_args.tensor_parallel_degree > 1 - if is_tp:
- lora_split_mapping = model._get_tensor_parallel_mappings(model.config) - loraga_reinit_modules(name, module, gradient_dict, stable_gamma, is_tp, lora_split_mapping, **kwargs) - - -def loraga_reinit_modules(name, module, gradient_dict, stable_gamma, is_tp=False, lora_split_mapping=None, **kwargs): - with paddle.no_grad(): - lora_r = module.r - grad_name = ".".join(name.split(".")[1:]) + ".weight" - loraA_name = ".".join(name.split(".")[1:]) + ".lora_A" - loraB_name = ".".join(name.split(".")[1:]) + ".lora_B" - grads = gradient_dict[grad_name] - - U, S, V = paddle.linalg.svd_lowrank(grads.astype("float32"), q=4 * lora_r, niter=4) - - V = V.T - A = U[:, lora_r : 2 * lora_r] - B = V[:lora_r, :] - m, n = grads.shape # m: feature_out, n: feature_in - # If stable_gamma is not -1, scale the matrices A and B by the square root of the stable_gamma - if stable_gamma != -1: - A = A * m**0.25 / stable_gamma**0.5 - B = B * m**0.25 / stable_gamma**0.5 - else: - A = A / module.scaling - B = B / module.scaling - if is_tp: - if module.lora_A.is_distributed and lora_split_mapping: - split_function = lora_split_mapping[loraA_name] - A = paddle.to_tensor(split_function(A)) - if module.lora_B.is_distributed and lora_split_mapping: - split_function = lora_split_mapping[loraB_name] - B = paddle.to_tensor(split_function(B)) - module.lora_A.set_value(A.astype(module.lora_A.dtype)) - module.lora_B.set_value(B.astype(module.lora_B.dtype)) - offset = module.lora_A @ module.lora_B - module.weight.data -= module.scaling * offset From 468d5492f54b0ae727051a59716bdc51627202c7 Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Tue, 10 Dec 2024 10:21:55 +0000 Subject: [PATCH 6/9] use loraga trainer --- paddlenlp/peft/lora/loraga_utils.py | 320 ++++++++++++---------------- 1 file changed, 139 insertions(+), 181 deletions(-) diff --git a/paddlenlp/peft/lora/loraga_utils.py b/paddlenlp/peft/lora/loraga_utils.py index b9f5f6871aee..a0fbbff7e6a6 100644 --- a/paddlenlp/peft/lora/loraga_utils.py +++ b/paddlenlp/peft/lora/loraga_utils.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
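
For reference, the SVD step that loraga_svd_module (patch 4) and the removed loraga_reinit_modules above both perform reduces, for a single linear layer, to the computation below. This is a minimal single-device sketch: the function name and the toy shapes in the usage lines are illustrative, and the tensor-parallel splitting of A and B done in the patch is omitted.

import paddle

def loraga_init_for_layer(weight, avg_grad, lora_r, scaling, stable_gamma=64):
    """Return (adjusted_weight, lora_A, lora_B) for one linear layer, LoRA-GA style."""
    # Low-rank SVD of the averaged gradient (same call the patch uses).
    U, S, V = paddle.linalg.svd_lowrank(avg_grad.astype("float32"), q=4 * lora_r, niter=4)
    V = V.T
    A = U[:, lora_r : 2 * lora_r]  # singular directions r..2r initialize lora_A
    B = V[:lora_r, :]              # singular directions 0..r initialize lora_B
    m = avg_grad.shape[0]
    if stable_gamma != -1:
        A = A * m**0.25 / stable_gamma**0.5
        B = B * m**0.25 / stable_gamma**0.5
    else:
        A = A / scaling
        B = B / scaling
    A = A.astype(weight.dtype)
    B = B.astype(weight.dtype)
    # Shift the base weight so the merged output is unchanged at initialization.
    adjusted_weight = weight - scaling * (A @ B)
    return adjusted_weight, A, B

# Toy usage (shapes illustrative): a [64, 32] weight with rank-4 adapters.
W = paddle.randn([64, 32], dtype="float32")
G = paddle.randn([64, 32], dtype="float32")  # averaged gradient of W
W_new, A, B = loraga_init_for_layer(W, G, lora_r=4, scaling=2.0)

Patch 4 then records the freshly initialized A and B in loraga_init_dict so that, on save, get_trainable_state_dict can concatenate them with the trained adapters and, on load, set_state_dict can split them again and re-apply the base-weight offset.
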
-from typing import Optional - import paddle import paddle.distributed as dist from paddle.distributed import fleet -from paddle.io import DataLoader, DistributedBatchSampler +from paddlenlp.peft import LoRAModel from paddlenlp.peft.lora.lora_layers import ( ColumnParallelLoRALinear, ColumnSequenceParallelLoRALinear, @@ -26,176 +24,110 @@ RowParallelLoRALinear, RowSequenceParallelLoRALinear, ) -from paddlenlp.trainer.trainer_utils import IterableDatasetShard -from paddlenlp.transformers.model_utils import unwrap_model +from paddlenlp.trainer import Trainer, TrainingArguments +from paddlenlp.trainer.trainer_utils import ShardingOption +from paddlenlp.transformers.model_utils import PretrainedModel, unwrap_model from paddlenlp.utils.log import logger -def wrap_loraga_model(model, training_args): - """Wrap Model with distributed strategies, support tp, dp, sharding""" - - from paddlenlp.trainer.trainer_utils import ShardingOption - from paddlenlp.transformers.model_utils import PretrainedModel - - sharding = None - if len(training_args.sharding) > 0: - if training_args.local_rank == -1: - raise ValueError("Using sharding only works in distributed training.") - sharding = True - - in_pipeline_parallel_mode = training_args.pipeline_parallel_degree > 1 - in_sharding_parallel_mode = sharding is not None - in_tensor_parallel_mode = training_args.tensor_parallel_degree > 1 - in_sep_parallel_mode = training_args.sep_parallel_degree > 1 - in_cp_parallel_mode = training_args.context_parallel_degree > 1 - - # Multi-gpu training - if training_args.world_size > 1 and (not training_args.use_hybrid_parallel): - # MOE use DDP to broadcaset parameters. - ddp_kwargs = {} - if training_args.ddp_find_unused_parameters is not None: - ddp_kwargs["find_unused_parameters"] = training_args.ddp_find_unused_parameters - elif isinstance(model, PretrainedModel): - # find_unused_parameters breaks checkpointing as per - # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 - ddp_kwargs["find_unused_parameters"] = not any( - hasattr(m, "enable_recompute") and m.enable_recompute for m in model.sublayers(include_self=True) - ) - else: - ddp_kwargs["find_unused_parameters"] = True - model = paddle.DataParallel(model, **ddp_kwargs) +class LoRAGATrainer(Trainer): + """A Trainer class for Lora-GA gradient estimation.""" - # No pipeline mode, sharding only - if not in_pipeline_parallel_mode and in_sharding_parallel_mode: - # Sharded DDP! - if training_args.tensor_parallel_degree > 1: - hcg = fleet.get_hybrid_communicate_group() - assert ( - ShardingOption.SHARD_GRAD_OP in training_args.sharding - or ShardingOption.SHARD_OP in training_args.sharding - ), "Only support tensor parallel + sharding stage1/stage2 hybrid parallel now." - model = paddle.distributed.fleet.meta_parallel.TensorParallel(model, hcg, strategy=None) - if ShardingOption.SHARD_OP in training_args.sharding: - model = fleet.distributed_model(model) - - if ( - not in_pipeline_parallel_mode - and not in_sharding_parallel_mode - and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode) - ): - model = fleet.distributed_model(model) - - return model - - -def get_loraga_dataloader(train_dataset, data_collator, training_args): - from paddlenlp.data import DistDataLoader - - def is_iterable_dataset(dataset): - return isinstance(dataset, paddle.io.IterableDataset) - - def is_iterable_dataset_distributed(dataset): - # For distributed dataloaer. 
- is_iterable_dataset_tensor = paddle.to_tensor(is_iterable_dataset(dataset)).astype("int32").reshape([1]) - if dist.get_world_size() > 1: - dist.all_reduce(is_iterable_dataset_tensor, op=dist.ReduceOp.MAX) - if is_iterable_dataset_tensor.item() == 1: - return True - return False - - if training_args.distributed_dataloader: - iterable_dataset = is_iterable_dataset_distributed(train_dataset) - else: - iterable_dataset = is_iterable_dataset(train_dataset) - - # if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset): - # train_dataset = self._remove_unused_columns(train_dataset, description="training") - _DataLoader = DistDataLoader if training_args.distributed_dataloader else DataLoader - - if iterable_dataset: # For iterable dataset - if training_args.dataset_world_size > 1 and train_dataset is not None: - train_dataset = IterableDatasetShard( - train_dataset, - batch_size=training_args.per_device_train_batch_size, - drop_last=training_args.dataloader_drop_last, - num_processes=training_args.dataset_world_size, - process_index=training_args.dataset_rank, - ) + def __init__(self, loraga_init_iters: int, gradient_offload: bool, **kwargs): + """ + Initialize the Trainer class for Lora-GA gradient estimation. - if training_args.distributed_dataloader: - logger.info("Training using DistDataLoader.") - additional_configs = {"is_iterable_dataset": True} - else: - additional_configs = {} - return _DataLoader( - train_dataset, - batch_size=training_args.per_device_train_batch_size, - collate_fn=data_collator, - num_workers=training_args.dataloader_num_workers, - **additional_configs, - ) - else: - train_sampler = get_loraga_train_sampler(train_dataset, training_args) - if training_args.distributed_dataloader: - logger.info("Training using DistDataLoader.") - return _DataLoader( - train_dataset, - batch_sampler=train_sampler, - collate_fn=data_collator, - num_workers=training_args.dataloader_num_workers, - ) - - -def get_loraga_train_sampler(train_dataset, training_args) -> Optional[paddle.io.Sampler]: - if training_args.world_size <= 1: - return paddle.io.BatchSampler( - dataset=train_dataset, - shuffle=True, - batch_size=training_args.per_device_train_batch_size, - drop_last=training_args.dataloader_drop_last, - ) - - return DistributedBatchSampler( - train_dataset, - batch_size=training_args.per_device_train_batch_size, - shuffle=True, - num_replicas=training_args.dataset_world_size, - rank=training_args.dataset_rank, - drop_last=training_args.dataloader_drop_last, - ) - - -def estimate_gradient(model, train_ds, data_collator, training_args, loraga_init_iters=32, gradient_offload=False): - """Estimate the gradient of the model on the given dataset""" - gradient_dict = {} - logger.info("Estimating gradient for LoraGA.") - - model = wrap_loraga_model(model, training_args) - model.train() - - logger.info(f"Initialization iterations for LoraGA: {loraga_init_iters}") - dataloader = get_loraga_dataloader(train_ds, data_collator, training_args) - iters = 0 - - with GradientOffloadHookContext( - model=model, - gradient_dict=gradient_dict, - local_rank=training_args.local_rank, - loraga_init_iters=loraga_init_iters, - gradient_offload=gradient_offload, - ): - for batch in dataloader: - iters += 1 - batch = {k: paddle.to_tensor(v) for k, v in batch.items()} + Args: + loraga_init_iters (int): The number of forward and backward process in initializing Lora-GA. + gradient_offload (bool): Whether to offload gradients to CPU memory. 
- # Pipeline parallel not supported currently - loss, logits = model(**batch) - loss.backward() + """ + super().__init__(**kwargs) + logger.info(f"Initialization iterations for LoraGA: {loraga_init_iters}") + self.loraga_init_iters = loraga_init_iters + self.gradient_offload = gradient_offload - if iters == loraga_init_iters: - break + def estimate_gradient(self, model: PretrainedModel): + """ + Estimate the gradient of the model on the given dataset + Args: + model (PretrainedModel): The base model to be trained. + + Returns: + dict: A dictionary containing the estimated gradients for each named layer. + Note: In tensor parallel mode, the gradients in the dict are not gathered. + """ + gradient_dict = {} + logger.info("Estimating gradient for LoraGA.") + + model = self._wrap_model(model) + model.train() + dataloader = self.get_train_dataloader() + iters = 0 + + with GradientOffloadHookContext( + model=model, + gradient_dict=gradient_dict, + local_rank=self.args.local_rank, + loraga_init_iters=self.loraga_init_iters, + gradient_offload=self.gradient_offload, + ): + for batch in dataloader: + iters += 1 + batch = {k: paddle.to_tensor(v) for k, v in batch.items()} + + # Pipeline parallel not supported currently + loss, logits = model(**batch) + loss.backward() + + if iters == self.loraga_init_iters: + break + return gradient_dict + + def _wrap_model(self, model): + """Wrap Model without optimizer, support dp, pp and sharding""" + + in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1 + in_sharding_parallel_mode = self.sharding is not None + in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1 + in_sep_parallel_mode = self.args.sep_parallel_degree > 1 + in_cp_parallel_mode = self.args.context_parallel_degree > 1 + + if in_pipeline_parallel_mode: + raise ValueError("LoRA-GA do not supported pipeline parallel currently.") + + # Multi-gpu training + if self.args.world_size > 1 and (not self.args.use_hybrid_parallel): + # MOE use DDP to broadcaset parameters. + ddp_kwargs = {} + if self.args.ddp_find_unused_parameters is not None: + ddp_kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters + elif isinstance(model, PretrainedModel): + # find_unused_parameters breaks checkpointing as per + # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + ddp_kwargs["find_unused_parameters"] = not any( + hasattr(m, "enable_recompute") and m.enable_recompute for m in model.sublayers(include_self=True) + ) + else: + ddp_kwargs["find_unused_parameters"] = True + model = paddle.DataParallel(model, **ddp_kwargs) + + # sharding + if in_sharding_parallel_mode: + # Sharded DDP! + if self.args.tensor_parallel_degree > 1: + hcg = fleet.get_hybrid_communicate_group() + assert ( + ShardingOption.SHARD_GRAD_OP in self.args.sharding or ShardingOption.SHARD_OP in self.args.sharding + ), "Only support tensor parallel + sharding stage1/stage2 hybrid parallel now." + model = paddle.distributed.fleet.meta_parallel.TensorParallel(model, hcg, strategy=None) + if ShardingOption.SHARD_OP in self.args.sharding: + model = fleet.distributed_model(model) + + if not in_sharding_parallel_mode and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode): + model = fleet.distributed_model(model) - return gradient_dict + return model def get_module_gradient( @@ -208,11 +140,28 @@ def get_module_gradient( dp_degree, local_rank, ): + """ + Gather modules gradient in tensor parallel mode. 
+ Average module gradient in data parallel mode and sharding parallel mode. + + Args: + grad_name (str): The name of the gradient parameter. + base_model_prefix (str): The prefix of the base model's parameter names. + gradient_dict (dict): A dictionary containing the estimated gradients for each named layer. + base_model_split_mappings (dict): A mapping of model keys to merge functions. + sharding_degree (int): The sharding parallel degree. + dp_degree (int): The data parallel degree. + local_rank (int): The local rank of the current process. + + Returns: + Tensor: The processed gradient tensor. + """ + rank_suffix = "_" + str(local_rank) local_grad_name = ".".join(grad_name.split(".")[1:]) + ".weight" + rank_suffix gradient = gradient_dict.pop(local_grad_name).cuda() if tp_degree > 1: - # remove prefix and suffix + # remove prefix and suffix in name model_split_key = local_grad_name.split(base_model_prefix)[-1].rsplit(rank_suffix, 1)[0] if model_split_key in base_model_split_mappings: merge_func = base_model_split_mappings[model_split_key] @@ -242,23 +191,27 @@ def get_module_gradient( return gradient -def loraga_svd_reinit(model, gradient_dict, base_model_split_mappings, stable_gamma, training_args, **kwargs) -> None: +def loraga_svd_reinit( + model: LoRAModel, gradient_dict: dict, stable_gamma: int, training_args: TrainingArguments, **kwargs +) -> None: """ - If Loraga has already been initialized, directly modify the base model weights. - Otherwise, reinitialize and save the initialized model. + Perform SVD to gradients and reinitialize base model weight and lora adapter weight. Args: - model (Any): The model to reinitialize. - gradient_dict (Dict[str, Any]): Dictionary containing gradients. - model_split_mappings (Any): Mappings for model tensor parallelism. - stable_gamma (Any): Stable gamma parameter for Loraga. - training_args (Any): Training arguments. - **kwargs: Additional keyword arguments. - """ + model (LoRAModel): The LoRAModel containing LoRA layers. + gradient_dict (dict): A dictionary containing the estimated gradients for each named layer. + stable_gamma (int): A scaling factor for LoRA-GA initialization. + training_args (TrainingArguments): Training arguments. - lora_split_mapping = None + Returns: + None: Updates the model's weights and LoRA adapter weights in place. + """ tensor_parallel_degree = training_args.tensor_parallel_degree in_tensor_parallel_mode = tensor_parallel_degree > 1 + lora_split_mapping = None + base_model_split_mappings = None + if in_tensor_parallel_mode: + base_model_split_mappings = model.model._get_tensor_parallel_mappings(config=model.config, is_split=False) base_model_prefix = unwrap_model(model).base_model_prefix + "." 
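Taken together with the LoRAGATrainer above, the intended flow is: estimate averaged gradients on the base model, wrap it as a LoRAModel, then hand the gradient dict to loraga_svd_reinit, which gathers per-module gradients (via get_module_gradient) and rewrites lora_A, lora_B and the frozen weights in place. The following rough wiring sketch assumes single-process defaults; the helper name apply_loraga, the LoRAConfig values, and the literal hyperparameters are illustrative assumptions, not the PR's actual call site.

from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.peft.lora.loraga_utils import LoRAGATrainer, loraga_svd_reinit

def apply_loraga(model, train_ds, data_collator, training_args):
    # 1) Estimate averaged gradients of the frozen base weights.
    trainer = LoRAGATrainer(
        loraga_init_iters=32,      # number of forward/backward passes used for estimation
        gradient_offload=False,    # keep recorded gradients on device for this sketch
        model=model,
        args=training_args,
        train_dataset=train_ds,
        data_collator=data_collator,
    )
    gradient_dict = trainer.estimate_gradient(model)

    # 2) Wrap the base model with LoRA adapters (config values illustrative).
    lora_config = LoRAConfig(target_modules=[".*q_proj.*", ".*v_proj.*"], r=8, lora_alpha=16, loraga=True)
    lora_model = LoRAModel(model, lora_config)

    # 3) SVD the estimated gradients and overwrite the adapters and base weights in place.
    loraga_svd_reinit(lora_model, gradient_dict, stable_gamma=64, training_args=training_args)
    return lora_model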
if in_tensor_parallel_mode: @@ -315,14 +268,15 @@ def loraga_svd_module( loraA_name = ".".join(name.split(".")[1:]) + ".lora_A" loraB_name = ".".join(name.split(".")[1:]) + ".lora_B" - + # Perform SVD to gradients U, S, V = paddle.linalg.svd_lowrank(grads.astype("float32"), q=4 * lora_r, niter=4) V = V.T - # get new low rank adapter after SVD + # get new low-rank adapter after SVD A = U[:, lora_r : 2 * lora_r] B = V[:lora_r, :] - m, n = grads.shape # m: feature_out, n: feature_in + + m, n = grads.shape # If stable_gamma is not -1, scale the matrices A and B by the square root of the stable_gamma if stable_gamma != -1: A = A * m**0.25 / stable_gamma**0.5 @@ -363,6 +317,8 @@ def get_hook_enable(): class GradientOffloadHookContext: + """Context manager for offloading gradient memory to CPU.""" + def __init__( self, model, @@ -373,7 +329,6 @@ def __init__( *args, **kwargs, ): - """Offload gradient to cpu""" self.model = model self.gradient_dict = gradient_dict self.local_rank = local_rank @@ -389,12 +344,15 @@ def __exit__(self, exc_type, exc_val, exc_tb): set_hook_enable(False) def register_gradient_hook(self): + """Register gradient hooks for all model parameters.""" for grad_name, param in self.model.named_parameters(): param._register_backward_hook( self.get_record_gradient_hook(self.model, self.gradient_dict, grad_name, param) ) def get_record_gradient_hook(self, model, gradient_dict, grad_name, param): + """Create a gradient recording hook for a parameter.""" + def record_gradient_hook(*_): if get_hook_enable(): grad = param.grad From b435ed674ef7db286841aa04c82323ff1fa30aa1 Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Tue, 10 Dec 2024 11:31:36 +0000 Subject: [PATCH 7/9] fix comment --- paddlenlp/peft/lora/loraga_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/peft/lora/loraga_utils.py b/paddlenlp/peft/lora/loraga_utils.py index a0fbbff7e6a6..27cba7770316 100644 --- a/paddlenlp/peft/lora/loraga_utils.py +++ b/paddlenlp/peft/lora/loraga_utils.py @@ -85,7 +85,7 @@ def estimate_gradient(self, model: PretrainedModel): return gradient_dict def _wrap_model(self, model): - """Wrap Model without optimizer, support dp, pp and sharding""" + """Wrap Model without optimizer, support dp, tp and sharding""" in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1 in_sharding_parallel_mode = self.sharding is not None From 86c5b3341c78d851542f86d83db3a44a86d0e694 Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Mon, 16 Dec 2024 06:34:44 +0000 Subject: [PATCH 8/9] fix loraga --- paddlenlp/peft/lora/lora_model.py | 11 ++++++----- paddlenlp/peft/lora/loraga_utils.py | 29 ++++++++++++++++++----------- paddlenlp/trainer/trainer.py | 4 ++++ 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index f52e662918eb..0bfa61364ece 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -165,6 +165,7 @@ def __init__(self, model, lora_config: LoRAConfig) -> None: self.forward = self.model.forward if lora_config.loraga: self.loraga_init_dict = {} + self.reinit_base_model = False logger.info("Mark only lora and trainable_module as trainable.") self.mark_only_lora_as_trainable() @@ -349,11 +350,11 @@ def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict): ) base_name = name.replace("lora_A", "weight") - - # Reinit base model - offset = init_loraA.cuda() @ init_loraB.cuda() - ori_weight = 
model_state_dict[base_name] - model_state_dict[base_name].set_value(ori_weight - self.lora_config.scaling * offset) + if not self.reinit_base_model: + # Reinit base model + offset = init_loraA.cuda() @ init_loraB.cuda() + ori_weight = model_state_dict[base_name] + model_state_dict[base_name].set_value(ori_weight - self.lora_config.scaling * offset) del model_state_dict gc.collect() self.model.set_state_dict(state_dict) diff --git a/paddlenlp/peft/lora/loraga_utils.py b/paddlenlp/peft/lora/loraga_utils.py index 27cba7770316..5c821772ea32 100644 --- a/paddlenlp/peft/lora/loraga_utils.py +++ b/paddlenlp/peft/lora/loraga_utils.py @@ -74,10 +74,9 @@ def estimate_gradient(self, model: PretrainedModel): ): for batch in dataloader: iters += 1 - batch = {k: paddle.to_tensor(v) for k, v in batch.items()} - # Pipeline parallel not supported currently - loss, logits = model(**batch) + with paddle.amp.auto_cast(enable=True, custom_black_list=self.args.amp_custom_black_list): + loss, logits = model(**batch) loss.backward() if iters == self.loraga_init_iters: @@ -160,13 +159,21 @@ def get_module_gradient( rank_suffix = "_" + str(local_rank) local_grad_name = ".".join(grad_name.split(".")[1:]) + ".weight" + rank_suffix gradient = gradient_dict.pop(local_grad_name).cuda() + + is_fleet_init = True + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + sharding_parallel_group = hcg.get_sharding_parallel_group() + data_parallel_group = hcg.get_data_parallel_group() + except: + is_fleet_init = False + if tp_degree > 1: # remove prefix and suffix in name model_split_key = local_grad_name.split(base_model_prefix)[-1].rsplit(rank_suffix, 1)[0] if model_split_key in base_model_split_mappings: merge_func = base_model_split_mappings[model_split_key] - hcg = fleet.get_hybrid_communicate_group() - model_parallel_group = hcg.get_model_parallel_group() output_tensors = [] dist.all_gather(output_tensors, gradient, group=model_parallel_group) @@ -175,18 +182,17 @@ def get_module_gradient( # sharding if sharding_degree > 1: - hcg = fleet.get_hybrid_communicate_group() - sharding_parallel_group = hcg.get_sharding_parallel_group() if sharding_parallel_group.nranks > 1: - dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=sharding_parallel_group) gradient /= sharding_parallel_group.nranks + # dp if dp_degree > 1: - hcg = fleet.get_hybrid_communicate_group() - data_parallel_group = hcg.get_data_parallel_group() if data_parallel_group.nranks > 1: - dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=data_parallel_group) + if is_fleet_init: + dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=data_parallel_group) + else: + dist.all_reduce(gradient, op=dist.ReduceOp.SUM) gradient /= data_parallel_group.nranks return gradient @@ -250,6 +256,7 @@ def loraga_svd_reinit( lora_split_mapping, **kwargs, ) + model.reinit_base_model = True model.loraga_init_dict = loraga_init_dict diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 57c655736f25..9859927d6946 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -623,6 +623,8 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): self.model, resume_from_checkpoint, ) + if isinstance(self.model, LoRAModel) and self.model.lora_config.loraga: + self.model.reinit_base_model = True logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.") self.runtime_timer.stop() return @@ -635,6 +637,8 @@ def _load_from_checkpoint(self, 
resume_from_checkpoint=None): or isinstance(self.model, ReFTModel) ): self._load_from_peft_checkpoint(resume_from_checkpoint) + if isinstance(self.model, LoRAModel) and self.model.lora_config.loraga: + self.model.reinit_base_model = True self.runtime_timer.stop() return From 275c6233dee2f7471d95fd263c890875d8b71814 Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Tue, 17 Dec 2024 13:59:59 +0000 Subject: [PATCH 9/9] change variable name --- paddlenlp/peft/lora/lora_model.py | 6 +++--- paddlenlp/peft/lora/loraga_utils.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index 0bfa61364ece..74d8220e25ce 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -334,18 +334,18 @@ def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict): final_lora, init_lora = paddle.split(concat_tensor, 2, axis=axis) init_dict[name] = init_lora state_dict[name] = final_lora - return final_lora, init_lora + return init_lora for name in state_dict.keys(): if "lora_A" in name: concat_lora_A = state_dict[name] - final_loraA, init_loraA = process_split_and_assign( + init_loraA = process_split_and_assign( name, concat_lora_A, axis=1, init_dict=self.loraga_init_dict, state_dict=state_dict ) loraB_name = name.replace("lora_A", "lora_B") concat_lora_B = state_dict[loraB_name] - final_loraB, init_loraB = process_split_and_assign( + init_loraB = process_split_and_assign( loraB_name, concat_lora_B, axis=0, init_dict=self.loraga_init_dict, state_dict=state_dict ) diff --git a/paddlenlp/peft/lora/loraga_utils.py b/paddlenlp/peft/lora/loraga_utils.py index 5c821772ea32..7400e2e3b88d 100644 --- a/paddlenlp/peft/lora/loraga_utils.py +++ b/paddlenlp/peft/lora/loraga_utils.py @@ -76,7 +76,7 @@ def estimate_gradient(self, model: PretrainedModel): iters += 1 # Pipeline parallel not supported currently with paddle.amp.auto_cast(enable=True, custom_black_list=self.args.amp_custom_black_list): - loss, logits = model(**batch) + loss, _ = model(**batch) loss.backward() if iters == self.loraga_init_iters:
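For the patch 8/9 changes above: process_split_and_assign assumes each checkpointed lora_A / lora_B tensor holds the trained adapter and the LoRA-GA init adapter concatenated along the rank axis (axis 1 for lora_A, axis 0 for lora_B); loading splits them apart, keeps the init half in loraga_init_dict, and re-applies the base-weight offset only while reinit_base_model is still False. A toy illustration of just the split step follows; the shapes are arbitrary and how the concatenated tensors are produced at save time is outside this excerpt.

import paddle

r, in_dim, out_dim = 8, 128, 64
final_A, init_A = paddle.randn([in_dim, r]), paddle.randn([in_dim, r])
final_B, init_B = paddle.randn([r, out_dim]), paddle.randn([r, out_dim])

# Checkpoint layout assumed by the hunk above: [final | init] per adapter.
saved_A = paddle.concat([final_A, init_A], axis=1)   # (in_dim, 2r)
saved_B = paddle.concat([final_B, init_B], axis=0)   # (2r, out_dim)

# What process_split_and_assign recovers on load.
loaded_final_A, loaded_init_A = paddle.split(saved_A, 2, axis=1)
loaded_final_B, loaded_init_B = paddle.split(saved_B, 2, axis=0)
assert bool(paddle.allclose(loaded_final_A, final_A))
assert bool(paddle.allclose(loaded_init_B, init_B))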