From e3cb5d239b91a370a2a91f113ab22052d5711748 Mon Sep 17 00:00:00 2001
From: Ruibiao Chen
Date: Tue, 5 Mar 2024 14:17:42 +0800
Subject: [PATCH] Fit sharding optimization for auto parallel llama (#8021)

* Fit sharding optimization for auto parallel llama

* Add args enable_allreduce_avg_in_gradinent_scale

* Fix CI errors
---
 paddlenlp/trainer/training_args.py | 37 +++++++++++++++++++++++++++++++++----
 1 file changed, 33 insertions(+), 4 deletions(-)

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index d0717a4ba787..26689989ed7f 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -230,6 +230,10 @@ class TrainingArguments:
             The paddle sequence parallel strategy. It can reduce the GPU memory of activation to 1/sep, and it is orthogonal to
             data parallel, sharding stage1, tensor parallel and pipeline parallel strategy.
         )
+        data_parallel_config (`str`, *optional*)(
+            Some additional configs which affect data parallel performance; we provide some options to configure them.
+            The following config is supported:
+            enable_allreduce_avg_in_gradinent_scale: replaces the `allreduce_sum + scale` pattern with `allreduce_avg` when scaling gradients in data parallel, which improves performance. ONLY supported for auto mode now.
         tensor_parallel_config (`str`, *optional*)(
             Some additional configs which affect model parallel performance, we provide some option to config it.
             following config is support:
@@ -571,6 +575,16 @@ class TrainingArguments:
             )
         },
     )
+    data_parallel_config: str = field(
+        default="",
+        metadata={
+            "help": (
+                "Some additional configs which affect data parallel performance; we provide some options to configure them. "
+                "The following config is supported:\n"
+                "enable_allreduce_avg_in_gradinent_scale: replaces the `allreduce_sum + scale` pattern with `allreduce_avg` when scaling gradients in data parallel, which improves performance. ONLY supported for auto mode now.\n"
+            )
+        },
+    )
     tensor_parallel_config: str = field(
         default="",
         metadata={
@@ -951,6 +965,7 @@ def __post_init__(self):
             # TODO use paddle.distributed.is_initialized() after paddle 2.4rc
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
                 strategy = fleet.DistributedStrategy()
+                assert self.data_parallel_config == "", "data_parallel_config is not supported in hybrid parallel"
             if self.pipeline_parallel_degree > 1:
                 pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
                 for x in pipeline_parallel_config:
@@ -1165,6 +1180,17 @@ def is_segment_parallel_supported():
                 warnings.warn("`offload` is not supported NOW!")
 
             strategy = fleet.auto.Strategy()
+            if self.data_parallel_degree > 1:
+                data_parallel_config = set(self.data_parallel_config.split(" "))
+                for x in data_parallel_config:
+                    if len(x) > 0:
+                        if x not in ["enable_allreduce_avg_in_gradinent_scale"]:
+                            raise ValueError(
+                                f"Found unknown data parallel config {x}, accepted config is enable_allreduce_avg_in_gradinent_scale."
+                            )
+                if "enable_allreduce_avg_in_gradinent_scale" in data_parallel_config:
+                    strategy.gradient_scale_using_allreduce_avg = True
+
             # navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
             if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
                 pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
@@ -1254,9 +1280,9 @@ def is_segment_parallel_supported():
                 for x in sharding_parallel_config:
                     if len(x) > 0:
                         if x not in [
-                            # "enable_stage1_tensor_fusion",
-                            # "enable_stage1_overlap",
-                            # "enable_stage2_overlap",
+                            "enable_stage1_tensor_fusion",
+                            "enable_stage1_overlap",
+                            "enable_stage2_overlap",
                         ]:
                             raise ValueError(
                                 f"Found unknown pipeline mode config {x}, " f"accpet config is reduce_overlap."
@@ -1266,7 +1292,10 @@ def is_segment_parallel_supported():
                     "enable_stage1_overlap" in sharding_parallel_config
                     or "enable_stage2_overlap" in sharding_parallel_config
                 ):
-                    sharding.reduce_overlap = True
+                    sharding.enable_overlap = True
+
+                if "enable_stage1_tensor_fusion" in sharding_parallel_config:
+                    sharding.grad_bucket_size_numel = 210355872
 
         if self.bf16 or self.fp16:
             amp = strategy.amp
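
Usage note (not part of the patch): a minimal sketch of how the new data_parallel_config option might be passed once this change lands. The output_dir value is an illustrative assumption; per the help text, the option only takes effect in auto parallel mode, and the new assert rejects it under hybrid parallel.

from paddlenlp.trainer import TrainingArguments

# Illustrative sketch: request allreduce_avg-based gradient scaling in data
# parallel. The option name matches the one introduced by this patch.
args = TrainingArguments(
    output_dir="./checkpoints",  # hypothetical path
    data_parallel_config="enable_allreduce_avg_in_gradinent_scale",
)
print(args.data_parallel_config)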
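
Similarly, a hedged sketch of the sharding options this patch re-enables; the sharding stage and degree values are assumptions for illustration. Multiple options go in one space-separated string, since the parser splits on spaces; when enable_stage1_tensor_fusion is set, the auto strategy pins grad_bucket_size_numel to 210355872 as shown in the diff.

from paddlenlp.trainer import TrainingArguments

# Illustrative sketch: space-separated sharding options under sharding stage1.
args = TrainingArguments(
    output_dir="./checkpoints",    # hypothetical path
    sharding="stage1",             # assumed sharding stage
    sharding_parallel_degree=8,    # assumed degree for illustration
    sharding_parallel_config="enable_stage1_tensor_fusion enable_stage1_overlap",
)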