diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py
index 2413855c4b4d..c6d6f2962daf 100644
--- a/paddlenlp/trainer/auto_trainer.py
+++ b/paddlenlp/trainer/auto_trainer.py
@@ -40,6 +40,11 @@
 )
 from .utils.helper import distributed_file, distributed_isfile  # nested_truncate,
 
+try:
+    from ..quantization.quantization_linear import QuantizationLinear
+except:
+    QuantizationLinear = None
+
 MODEL_NAME = "model"
 OPTIMIZER_NAME = "optimizer"
 DIST_CKPT_PATH = "dist_ckpt"
@@ -112,12 +117,22 @@ def _wrap_for_auto(self, model, train_dataloader):
             self.optimizer = dist.shard_optimizer(self.optimizer)
         return model, dist_loader
 
-    def _wrap_amp_model(self):
+    def _wrap_amp_model(self, args, model):
         logger.info("Using half precision")
+        if args.to_static:
+            return
         self.enable_autocast_context_manager = True
         self.do_grad_scaling = True if self.args.fp16 else False
         self.amp_dtype = "float16" if self.args.fp16 else "bfloat16"
-        self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
+        self.scaler = dist.shard_scaler(paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss))
+        if self.args.fp16_opt_level == "O2":
+            paddle.amp.decorate(
+                models=model,
+                level=self.args.fp16_opt_level,
+                dtype=self.amp_dtype,
+                master_grad=self.args.amp_master_grad,
+                excluded_layers=QuantizationLinear,
+            )
 
     def _get_item_from_loss(self, loss):
         if isinstance(loss, paddle.Tensor):
diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index 7259c9f4c657..3905bf4f9efe 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -153,7 +153,8 @@ def scaled_dot_product_attention(
             )
 
         attn_weights = attn_weights + attention_mask
-        attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)
+        with paddle.amp.auto_cast(False):
+            attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)
 
         attn_output = paddle.matmul(attn_weights, value_states)
         attn_output = attn_output.transpose([0, 2, 1, 3])
@@ -177,10 +178,7 @@ def forward(self, hidden_states):
         if self.config.use_fused_rms_norm:
             return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
 
-        if paddle.in_dynamic_mode():
-            variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
-            hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
-        else:
+        with paddle.amp.auto_cast(False):
             variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
             hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
 
@@ -1050,20 +1048,21 @@ def forward(self, prediction_scores, masked_lm_labels):
             masked_lm_labels = dist.reshard(masked_lm_labels, get_mesh(-1), [dist.Replicate(), dist.Replicate()])
 
         # Force entropy same kernel
-        if isinstance(prediction_scores, paddle.Tensor):
-            masked_lm_loss = self.loss_func(
-                prediction_scores.astype("float32")._use_gpudnn(False),
-                masked_lm_labels.unsqueeze(2),
-            )
-        else:
+        with paddle.amp.auto_cast(False):
+            if isinstance(prediction_scores, paddle.Tensor):
+                masked_lm_loss = self.loss_func(
+                    prediction_scores.astype("float32")._use_gpudnn(False),
+                    masked_lm_labels.unsqueeze(2),
+                )
+            else:
-            masked_lm_loss = self.loss_func(
-                prediction_scores.astype("float32"),
-                masked_lm_labels.unsqueeze(2),
-            )
+                masked_lm_loss = self.loss_func(
+                    prediction_scores.astype("float32"),
+                    masked_lm_labels.unsqueeze(2),
+                )
 
-        masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
-        loss = paddle.mean(masked_lm_loss)
+            masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
+            loss = paddle.mean(masked_lm_loss)
 
         return loss
diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh
index 825b22683989..93e58b38f6cb 100644
--- a/scripts/distribute/ci_case_auto.sh
+++ b/scripts/distribute/ci_case_auto.sh
@@ -48,6 +48,7 @@ function llama_case_list_auto() {
     llama_dygraph_auto_bs8_fp32_DP2
     llama_dygraph_auto_bs8_fp32_DP2-MP2
    llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2
+    llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2
 
    llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1
    llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1
@@ -1440,6 +1441,74 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
+
+function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
+    echo "=========== $FUNCNAME run begin ==========="
+    export PYTHONPATH=$root_path/:$PYTHONPATH
+    export FLAGS_call_stack_level=3
+    export NVIDIA_TF32_OVERRIDE=0
+
+    task_name="llama_auto_bs8_fp16_dp2mp2pp2"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+    rm -rf $case_out_dir
+    rm -rf $case_log_dir
+
+    python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir $case_log_dir run_pretrain_auto.py \
+        --model_type "llama" \
+        --model_name_or_path "facebook/llama-7b" \
+        --tokenizer_name_or_path "facebook/llama-7b" \
+        --input_dir "./data" \
+        --output_dir $case_out_dir \
+        --split 949,50,1 \
+        --max_seq_length 2048 \
+        --hidden_size 1024 \
+        --intermediate_size 3072 \
+        --num_hidden_layers 8 \
+        --num_attention_heads 32 \
+        --per_device_train_batch_size 1 \
+        --per_device_eval_batch_size 4 \
+        --gradient_accumulation_steps 4 \
+        --use_flash_attention 0 \
+        --use_fused_rms_norm 0 \
+        --fp16 1 \
+        --fp16_opt_level "O2" \
+        --amp_master_grad 1 \
+        --scale_loss 1024 \
+        --pipeline_parallel_degree 2 \
+        --tensor_parallel_degree 2 \
+        --sharding_parallel_degree 1 \
+        --learning_rate 0.0001 \
+        --min_learning_rate 0.00001 \
+        --max_steps 10 \
+        --save_steps 5000 \
+        --weight_decay 0.01 \
+        --warmup_ratio 0.01 \
+        --logging_steps 1 \
+        --dataloader_num_workers 1 \
+        --sharding "" \
+        --eval_steps 1000000 \
+        --disable_tqdm true \
+        --continue_training 0 \
+        --recompute 0 \
+        --do_train \
+        --do_eval \
+        --device "gpu" \
+        --data_impl "mmap" \
+        --enable_auto_parallel 1 \
+        --to_static 0 \
+        --max_grad_norm 1.0 \
+        >>${log_path}/$FUNCNAME 2>&1
+    loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    ips=-1
+    mem=-1
+    echo "result: loss=$loss ips=$ips mem=$mem"
+    loss_base=9.38341904
+    ips_base=-1
+    mem_base=-1
+    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
 ############ case end ############
 
 function check_result() {
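
For context: the modeling_auto.py changes above all apply the same pattern, wrapping numerically sensitive ops (the attention softmax, the RMSNorm variance, the cross-entropy loss) in paddle.amp.auto_cast(False) so they are computed in float32 even when the rest of the forward runs under fp16/bf16 autocast. The snippet below is a minimal, self-contained sketch of that pattern only; the tensor shape and the surrounding O2 autocast scope are illustrative assumptions, not values taken from run_pretrain_auto.py, and fp16 kernels only take effect on GPU.

# Sketch of the auto_cast(False) pattern used in the patch (illustrative shapes).
import paddle
import paddle.nn.functional as F

attn_weights = paddle.randn([2, 8, 128, 128])  # float32 by default

with paddle.amp.auto_cast(True, level="O2", dtype="float16"):
    # Ops in this scope may run in float16 on GPU. The nested
    # auto_cast(False) block opts the softmax out of autocast, so it is
    # computed in float32 and then cast back to the surrounding dtype.
    with paddle.amp.auto_cast(False):
        probs = F.softmax(attn_weights, axis=-1, dtype="float32").astype(attn_weights.dtype)

print(probs.dtype)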