
Commit

【AutoParallel】Add semi autoparallel amp (#7985)
* add semi-autoparallel amp

* support amp in semi-auto

* change loss base

* polish
heavyrain-lzy authored Feb 21, 2024
1 parent 5c9c8d3 commit b2be2fc
Showing 3 changed files with 102 additions and 19 deletions.
19 changes: 17 additions & 2 deletions paddlenlp/trainer/auto_trainer.py
@@ -40,6 +40,11 @@
 )
 from .utils.helper import distributed_file, distributed_isfile  # nested_truncate,
 
+try:
+    from ..quantization.quantization_linear import QuantizationLinear
+except:
+    QuantizationLinear = None
+
 MODEL_NAME = "model"
 OPTIMIZER_NAME = "optimizer"
 DIST_CKPT_PATH = "dist_ckpt"
@@ -112,12 +117,22 @@ def _wrap_for_auto(self, model, train_dataloader):
         self.optimizer = dist.shard_optimizer(self.optimizer)
         return model, dist_loader
 
-    def _wrap_amp_model(self):
+    def _wrap_amp_model(self, args, model):
         logger.info("Using half precision")
+        if args.to_static:
+            return
         self.enable_autocast_context_manager = True
         self.do_grad_scaling = True if self.args.fp16 else False
         self.amp_dtype = "float16" if self.args.fp16 else "bfloat16"
-        self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
+        self.scaler = dist.shard_scaler(paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss))
+        if self.args.fp16_opt_level == "O2":
+            paddle.amp.decorate(
+                models=model,
+                level=self.args.fp16_opt_level,
+                dtype=self.amp_dtype,
+                master_grad=self.args.amp_master_grad,
+                excluded_layers=QuantizationLinear,
+            )
 
     def _get_item_from_loss(self, loss):
         if isinstance(loss, paddle.Tensor):
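
Note (not part of the diff): the block below is a minimal single-card sketch of how the pieces configured in _wrap_amp_model (an O2 paddle.amp.decorate call, a GradScaler for loss scaling, and an auto_cast region around the forward pass) typically fit together. The auto-parallel wiring (dist.shard_scaler, dist.shard_optimizer, the distributed dataloader) and the real model are omitted; the layer shape and loss-scaling value are placeholders.

import paddle

# Stand-ins for the real model and optimizer; assumes a GPU is available.
model = paddle.nn.Linear(1024, 1024)
opt = paddle.optimizer.AdamW(parameters=model.parameters())

# O2: cast the model to float16; decorate also manages master weights.
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level="O2", dtype="float16")
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn([8, 1024])
with paddle.amp.auto_cast(True, level="O2", dtype="float16"):
    loss = model(x).mean()

scaled = scaler.scale(loss)  # scale the loss so fp16 gradients do not underflow
scaled.backward()
scaler.step(opt)             # unscale gradients, then run the optimizer step
scaler.update()              # adjust the loss-scaling factor for the next step
opt.clear_grad()
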
33 changes: 16 additions & 17 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -153,7 +153,8 @@ def scaled_dot_product_attention(
             )
 
         attn_weights = attn_weights + attention_mask
-        attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)
+        with paddle.amp.auto_cast(False):
+            attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)
 
         attn_output = paddle.matmul(attn_weights, value_states)
         attn_output = attn_output.transpose([0, 2, 1, 3])
@@ -177,10 +178,7 @@ def forward(self, hidden_states):
         if self.config.use_fused_rms_norm:
             return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
 
-        if paddle.in_dynamic_mode():
-            variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
-            hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
-        else:
+        with paddle.amp.auto_cast(False):
             variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
             hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
 
@@ -1050,20 +1048,21 @@ def forward(self, prediction_scores, masked_lm_labels):
             masked_lm_labels = dist.reshard(masked_lm_labels, get_mesh(-1), [dist.Replicate(), dist.Replicate()])
 
         # Force entropy same kernel
-        if isinstance(prediction_scores, paddle.Tensor):
-            masked_lm_loss = self.loss_func(
-                prediction_scores.astype("float32")._use_gpudnn(False),
-                masked_lm_labels.unsqueeze(2),
-            )
-        else:
+        with paddle.amp.auto_cast(False):
+            if isinstance(prediction_scores, paddle.Tensor):
+                masked_lm_loss = self.loss_func(
+                    prediction_scores.astype("float32")._use_gpudnn(False),
+                    masked_lm_labels.unsqueeze(2),
+                )
+            else:
 
-            masked_lm_loss = self.loss_func(
-                prediction_scores.astype("float32"),
-                masked_lm_labels.unsqueeze(2),
-            )
+                masked_lm_loss = self.loss_func(
+                    prediction_scores.astype("float32"),
+                    masked_lm_labels.unsqueeze(2),
+                )
 
-        masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
-        loss = paddle.mean(masked_lm_loss)
+            masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
+            loss = paddle.mean(masked_lm_loss)
         return loss


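Note (not part of the diff): the three hunks above all apply the same pattern, wrapping numerically sensitive reductions (the attention softmax, the RMSNorm variance, and the pretraining loss) in paddle.amp.auto_cast(False) so they run in float32 inside an otherwise fp16 autocast region. A standalone sketch of that pattern, with placeholder shapes and epsilon:

import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 8, 128])

with paddle.amp.auto_cast(True, dtype="float16"):
    # matmul is autocast to float16 (on GPU), like the attention score computation
    scores = paddle.matmul(x, x, transpose_y=True)

    with paddle.amp.auto_cast(False):
        # softmax accumulated in float32 for stability, then cast back to the compute dtype
        probs = F.softmax(scores, axis=-1, dtype="float32").astype(scores.dtype)

        # RMSNorm-style variance is also kept in float32 before rescaling
        variance = x.astype("float32").pow(2).mean(-1, keepdim=True)
        normed = (paddle.rsqrt(variance + 1e-6) * x.astype("float32")).astype(x.dtype)

print(probs.dtype, normed.dtype)
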
69 changes: 69 additions & 0 deletions scripts/distribute/ci_case_auto.sh
@@ -48,6 +48,7 @@ function llama_case_list_auto() {
     llama_dygraph_auto_bs8_fp32_DP2
     llama_dygraph_auto_bs8_fp32_DP2-MP2
     llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2
+    llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2
 
     llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1
     llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1
@@ -1440,6 +1441,74 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
+
+function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
+    echo "=========== $FUNCNAME run begin ==========="
+    export PYTHONPATH=$root_path/:$PYTHONPATH
+    export FLAGS_call_stack_level=3
+    export NVIDIA_TF32_OVERRIDE=0
+
+    task_name="llama_auto_bs8_fp16_dp2mp2pp2"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+    rm -rf $case_out_dir
+    rm -rf $case_log_dir
+
+    python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir $case_log_dir run_pretrain_auto.py \
+        --model_type "llama" \
+        --model_name_or_path "facebook/llama-7b" \
+        --tokenizer_name_or_path "facebook/llama-7b" \
+        --input_dir "./data" \
+        --output_dir $case_out_dir \
+        --split 949,50,1 \
+        --max_seq_length 2048 \
+        --hidden_size 1024 \
+        --intermediate_size 3072 \
+        --num_hidden_layers 8 \
+        --num_attention_heads 32 \
+        --per_device_train_batch_size 1 \
+        --per_device_eval_batch_size 4 \
+        --gradient_accumulation_steps 4 \
+        --use_flash_attention 0 \
+        --use_fused_rms_norm 0 \
+        --fp16 1 \
+        --fp16_opt_level "O2" \
+        --amp_master_grad 1 \
+        --scale_loss 1024 \
+        --pipeline_parallel_degree 2 \
+        --tensor_parallel_degree 2 \
+        --sharding_parallel_degree 1 \
+        --learning_rate 0.0001 \
+        --min_learning_rate 0.00001 \
+        --max_steps 10 \
+        --save_steps 5000 \
+        --weight_decay 0.01 \
+        --warmup_ratio 0.01 \
+        --logging_steps 1 \
+        --dataloader_num_workers 1 \
+        --sharding "" \
+        --eval_steps 1000000 \
+        --disable_tqdm true \
+        --continue_training 0 \
+        --recompute 0 \
+        --do_train \
+        --do_eval \
+        --device "gpu" \
+        --data_impl "mmap" \
+        --enable_auto_parallel 1 \
+        --to_static 0 \
+        --max_grad_norm 1.0 \
+        >>${log_path}/$FUNCNAME 2>&1
+    loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    ips=-1
+    mem=-1
+    echo "result: loss=$loss ips=$ips mem=$mem"
+    loss_base=9.38341904
+    ips_base=-1
+    mem_base=-1
+    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
 ############ case end ############
 
 function check_result() {
