diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py
index be252791d3a2..93dbf5d95f46 100644
--- a/paddlenlp/trainer/auto_trainer.py
+++ b/paddlenlp/trainer/auto_trainer.py
@@ -118,11 +118,19 @@ def _wrap_for_auto(self, model, train_dataloader):
         dist_loader = self._wrap_for_dist_loader(train_dataloader)
 
         if ShardingOption.SHARD_OP in self.args.sharding:
-            self.optimizer = dist.shard_optimizer(self.optimizer, dist.ShardingStage1())
+            self.optimizer = dist.shard_optimizer(
+                self.optimizer, dist.ShardingStage1(), self.args.gradient_accumulation_steps
+            )
         elif ShardingOption.SHARD_GRAD_OP in self.args.sharding:
-            self.optimizer = dist.shard_optimizer(self.optimizer, dist.ShardingStage2())
+            self.optimizer = dist.shard_optimizer(
+                self.optimizer, dist.ShardingStage2(), self.args.gradient_accumulation_steps
+            )
         elif ShardingOption.FULL_SHARD in self.args.sharding:
-            self.optimizer = dist.shard_optimizer(self.optimizer, dist.ShardingStage3())
+            self.optimizer = dist.shard_optimizer(
+                self.optimizer, dist.ShardingStage3(), self.args.gradient_accumulation_steps
+            )
+        else:
+            self.optimizer = dist.shard_optimizer(self.optimizer, None, self.args.gradient_accumulation_steps)
 
         if self.args.to_static:
             unified_strategy = dist.Strategy()
@@ -173,41 +181,54 @@ def _split_batches_for_accumulation(self, inputs):
         if self.args.to_static and self._in_pir_mode and self.args.gradient_accumulation_steps > 1:
             return [inputs]
 
-        local_batches = [{} for i in range(self.args.gradient_accumulation_steps)]
+        global_micro_batchs = [{} for i in range(self.args.gradient_accumulation_steps)]
         assert isinstance(inputs, dict)
 
-        def split_dtensor_by_axis(dtensor, axis):
-            mesh = dtensor.process_mesh
-            placements = [dist.Replicate() for _ in range(len(mesh.shape))]
-            replicate_value = dist.reshard(dtensor, mesh, placements)
-            local_datas = replicate_value.split(self.args.gradient_accumulation_steps, axis=0)
-            return local_datas
+        def split_dtensor_by_axis(dtensor, axis=0):
+            if not dtensor._is_initialized():
+                return dtensor.split(self.args.gradient_accumulation_steps, axis=axis)
+
+            micro_batch_shape = dtensor.shape
+            micro_batch_shape[axis] = int(dtensor.shape[axis] / self.args.gradient_accumulation_steps)
+
+            global_micro_batchs = [
+                paddle.zeros(micro_batch_shape, dtype=dtensor.dtype)
+                for _ in range(self.args.gradient_accumulation_steps)
+            ]
+            global_micro_batchs = [
+                dist.shard_tensor(b, dtensor.process_mesh, dtensor.placements) for b in global_micro_batchs
+            ]
+
+            local_micro_batchs = dtensor._local_value().split(self.args.gradient_accumulation_steps, axis=axis)
+            for local_micro_batch, global_micro_batch in zip(local_micro_batchs, global_micro_batchs):
+                paddle.assign(local_micro_batch, global_micro_batch._local_value())
+            return global_micro_batchs
 
         for key, dtensors in inputs.items():
             if isinstance(dtensors, paddle.Tensor):
                 mesh, placements = dtensors.process_mesh, dtensors.placements
-                local_datas = split_dtensor_by_axis(dtensors, 0)
-                for index, data in enumerate(local_datas):
-                    local_batches[index].update({key: dist.reshard(data, mesh, placements)})
+                global_datas = split_dtensor_by_axis(dtensors, 0)
+                for index, data in enumerate(global_datas):
+                    global_micro_batchs[index].update({key: dist.reshard(data, mesh, placements)})
             elif isinstance(dtensors, (list, tuple)):
                 if len(dtensors) == 0:
                     for i in range(self.args.gradient_accumulation_steps):
-                        local_batches[i].update({key: []})
+                        global_micro_batchs[i].update({key: []})
                 else:
                     for dtensor in dtensors:
                         if isinstance(dtensor, paddle.Tensor):
                             mesh, placements = dtensor.process_mesh, dtensor.placements
-                            local_datas = split_dtensor_by_axis(dtensor, 0)
-                            for index, data in enumerate(local_datas):
-                                if key in local_batches[index].keys():
-                                    local_batches[index][key].append(dist.reshard(data, mesh, placements))
+                            global_datas = split_dtensor_by_axis(dtensor, 0)
+                            for index, data in enumerate(global_datas):
+                                if key in global_micro_batchs[index].keys():
+                                    global_micro_batchs[index][key].append(dist.reshard(data, mesh, placements))
                                 else:
-                                    local_batches[index].update({key: [dist.reshard(data, mesh, placements)]})
+                                    global_micro_batchs[index].update({key: [dist.reshard(data, mesh, placements)]})
                         else:
                             raise ValueError(f"unsupported type: {type(dtensor)}")
             else:
                 raise ValueError(f"unsupported type: {type(dtensors)}")
-        return local_batches
+        return global_micro_batchs
 
     def _inner_training_loop(
         self,
@@ -372,6 +393,9 @@ def _inner_training_loop(
 
                 self.timers and self.timers("optimizer-step").start()
 
+                if self.args.gradient_accumulation_steps > 1 and self._enable_delay_scale_loss():
+                    tr_loss /= self.args.gradient_accumulation_steps
+
                 # Optimizer step
                 self.callback_handler.on_optimizer_begin(
                     args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
@@ -503,11 +527,11 @@ def to_list(value):
 
         return (loss, outputs) if return_outputs else loss
 
-    def dynamic_traning(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
+    def dynamic_training(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)
 
-        if loss is not None and self.args.gradient_accumulation_steps > 1:
+        if loss is not None and self.args.gradient_accumulation_steps > 1 and not self._enable_delay_scale_loss():
             loss = loss / self.args.gradient_accumulation_steps
 
         if self.do_grad_scaling:
@@ -517,11 +541,11 @@ def dynamic_traning(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor
 
         return loss
 
-    def static_traning(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
+    def static_training(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
         input_ids, labels = tuple(inputs.values())
         loss = model(input_ids, labels)
 
-        if loss is not None and self.args.gradient_accumulation_steps > 1:
+        if loss is not None and self.args.gradient_accumulation_steps > 1 and not self._enable_delay_scale_loss():
             loss = loss / self.args.gradient_accumulation_steps
 
         return loss
@@ -532,9 +556,9 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
         inputs = self._prepare_inputs(inputs)
 
         if not self.args.to_static:
-            loss = self.dynamic_traning(model, inputs)
+            loss = self.dynamic_training(model, inputs)
         else:
-            loss = self.static_traning(model, inputs)
+            loss = self.static_training(model, inputs)
 
         if isinstance(loss, paddle.Tensor):
             return loss.detach() if loss._is_initialized() else float(0.0)
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index c7cfa72462b4..c3e7d91f74ba 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -2197,6 +2197,9 @@ def compute_loss(self, model, inputs, return_outputs=False):
         return (loss, outputs) if return_outputs else loss
 
     def _enable_delay_scale_loss(self):
+        if in_auto_parallel_align_mode():
+            return True
+
         key = "enable_delay_scale_loss"
         if self.args.pipeline_parallel_degree > 1:
             return key in self.args.pipeline_parallel_config
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 6aed76f17cea..567349dfde1c 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1472,7 +1472,7 @@ def is_segment_parallel_supported():
                     "enable_send_recv_overlap",
                     # "disable_p2p_cache_shape",  # no need for auto_parallel
                     # "disable_partial_send_recv",  # no implemenation for auto_parallel
-                    # "enable_delay_scale_loss",  # default True in auto_parallel, non-configurable
+                    "enable_delay_scale_loss",
                     # "enable_dp_comm_overlap",  # no implemenation for auto_parallel
                     # "enable_sharding_comm_overlap",  # no implemenation for auto_parallel
                     # "enable_timer",  # no implemenation for auto_parallel
@@ -1521,6 +1521,7 @@ def is_segment_parallel_supported():
                 if len(x) > 0:
                     if x not in [
                         "enable_mp_async_allreduce",  # allreduce_matmul_grad_overlapping in auto_parallel
+                        "enable_delay_scale_loss",
                         # "enable_mp_skip_c_identity",
                         # "enable_mp_fused_linear_param_grad_add",
                     ]:
diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh
index 3d329556ccee..e1bfef19e1d3 100755
--- a/scripts/distribute/ci_case_auto.sh
+++ b/scripts/distribute/ci_case_auto.sh
@@ -561,7 +561,7 @@ function llama_dygraph_auto_bs8_fp32_DP2() {
        echo "result: loss=$loss ips=$ips mem=$mem"
        loss_base=9.51876831
        if [ $IS_A100 -ne 0 ];then
-           loss_base=9.53083992
+           loss_base=9.53084087
        fi
        ips_base=-1
        mem_base=-1
@@ -629,9 +629,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() {
        ips=-1
        mem=-1
        echo "result: loss=$loss ips=$ips mem=$mem"
-       loss_base=9.35078526
+       loss_base=9.3507843
        if [ $IS_A100 -ne 0 ];then
-           loss_base=9.38577652
+           loss_base=9.38577747
        fi
        ips_base=-1
        mem_base=-1
@@ -699,7 +699,7 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
        ips=-1
        mem=-1
        echo "result: loss=$loss ips=$ips mem=$mem"
-       loss_base=9.35139465
+       loss_base=9.3513937
        if [ $IS_A100 -ne 0 ];then
            loss_base=9.39356422
        fi
@@ -770,7 +770,7 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
        ips=-1
        mem=-1
        echo "result: loss=$loss ips=$ips mem=$mem"
-       loss_base=9.35162354
+       loss_base=9.35162258
        if [ $IS_A100 -ne 0 ];then
            loss_base=9.39368534
        fi
@@ -1063,7 +1063,7 @@ function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP() {
        loss_base=9.16783295
        loss_md5_base=8ea72495fba4e1b9ba004b4431e27218
        if [ $IS_A100 -ne 0 ];then
-           loss_base=9.37966919
+           loss_base=9.38009949
        fi
        ips_base=-1
        mem_base=-1
@@ -1253,7 +1253,7 @@ function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP() {
        loss_base=9.25199432
        loss_md5_base=83531e98ee11cd271db175150ab254bb
        if [ $IS_A100 -ne 0 ];then
-           loss_base=9.44203949
+           loss_base=9.44241714
        fi
        ips_base=-1
        mem_base=-1
@@ -1546,10 +1546,10 @@ function llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4() {
            >>${log_path}/$FUNCNAME 2>&1
        loss=$(grep "global_step: 10," "$case_log_dir/workerlog.0" | grep -oP '(?<=loss: )\d+(\.\d+)?' | awk -F ',' '{print $1}')
-       if [ "$pp_mode" == "FThenB" ]; then
-           loss1=loss
+       if [ "$pp_mode" == "1F1B" ]; then
+           loss1=($loss)
        else
-           loss2=loss
+           loss2=($loss)
        fi
        echo "result: $pp_mode loss=$loss"
    done
@@ -1887,12 +1887,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2() {
    ips=-1
    mem=-1
    echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
-   loss_base=10.59205246
-   loss_md5_base=0ebf68698887b33b33a46518621cf412
+   loss_base=10.59368134
    ips_base=-1
    mem_base=-1
    if [ $IS_A100 -ne 0 ];then
-       loss_base=10.60499191
+       loss_base=10.60190201
    fi
    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
    echo "=========== $FUNCNAME run end ==========="
@@ -1960,12 +1959,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2() {
    ips=-1
    mem=-1
    echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
-   loss_base=10.58860683
-   loss_md5_base=6df87d01bd08113a92930f6349514b35
+   loss_base=10.5913763
    ips_base=-1
    mem_base=-1
    if [ $IS_A100 -ne 0 ];then
-       loss_base=10.59338379
+       loss_base=10.5915575
    fi
    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
    echo "=========== $FUNCNAME run end ==========="
@@ -2034,12 +2032,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
    mem=-1
    echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
    # loss_base=10.59993172 # note: need to debug
-   loss_base=10.59993267
-   loss_md5_base=6cb4e151b35f026190df90ab240d9a95
+   loss_base=10.59891224
    ips_base=-1
    mem_base=-1
    if [ $IS_A100 -ne 0 ];then
-       loss_base=10.59612274
+       loss_base=10.60014629
    fi
    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
    echo "=========== $FUNCNAME run end ==========="
@@ -2108,13 +2105,12 @@ function llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
    mem=-1
    echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
    # loss_base=10.58456802 # note: need to debug
-   loss_base=10.6004734
-   loss_md5_base=e82a1f5668870d18a2d45b3ee0a25386
+   loss_base=10.59941483
    ips_base=-1
    mem_base=-1
    if [ $IS_A100 -ne 0 ];then
        # loss_base=10.58141422 # note: need to debug
-       loss_base=10.59650803
+       loss_base=10.60039139
    fi
    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
    echo "=========== $FUNCNAME run end ==========="
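
Note on the core behavioral change: when `enable_delay_scale_loss` is active (and `_enable_delay_scale_loss()` now always returns True in auto-parallel align mode), `dynamic_training`/`static_training` no longer divide each micro-batch loss by `gradient_accumulation_steps`; instead, `_inner_training_loop` divides the accumulated `tr_loss` once at the optimizer step. The snippet below is a minimal, framework-agnostic sketch of why the two schemes report the same averaged loss up to floating-point rounding; the names (`micro_losses`, `accumulation_steps`) are illustrative only and not taken from the PaddleNLP code.

# Minimal sketch (illustrative names, not PaddleNLP APIs) of the two
# loss-scaling schemes that `enable_delay_scale_loss` switches between.

accumulation_steps = 4
micro_losses = [2.0, 3.0, 5.0, 6.0]  # per-micro-batch losses within one global step

# Scheme A: scale every micro-batch loss before backward
# (dynamic_training/static_training when delayed scaling is disabled).
per_micro_batch = sum(loss / accumulation_steps for loss in micro_losses)

# Scheme B: accumulate unscaled losses, then divide once at the optimizer step
# (the new `tr_loss /= gradient_accumulation_steps` path when it is enabled).
delayed = sum(micro_losses) / accumulation_steps

# Both report the same averaged loss (4.0 here), up to floating-point rounding.
assert abs(per_micro_batch - delayed) < 1e-12

The ci_case_auto.sh edits refresh the expected loss baselines (and drop the loss_md5 checks for the GPT cases) in line with this change in where the scaling happens.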