【AutoParallel】Add semi autoparallel amp #7985

Merged: 5 commits, Feb 21, 2024
19 changes: 17 additions & 2 deletions paddlenlp/trainer/auto_trainer.py
@@ -40,6 +40,11 @@
 )
 from .utils.helper import distributed_file, distributed_isfile # nested_truncate,
 
+try:
+    from ..quantization.quantization_linear import QuantizationLinear
+except:
+    QuantizationLinear = None
+
 MODEL_NAME = "model"
 OPTIMIZER_NAME = "optimizer"
 DIST_CKPT_PATH = "dist_ckpt"
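The guarded import keeps the trainer importable when the quantization module is unavailable: as the next hunk shows, `QuantizationLinear` is only used as the `excluded_layers` argument of `paddle.amp.decorate`, and falling back to `None` simply excludes no layers.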
@@ -112,12 +117,22 @@
             self.optimizer = dist.shard_optimizer(self.optimizer)
         return model, dist_loader
 
-    def _wrap_amp_model(self):
+    def _wrap_amp_model(self, args, model):
         logger.info("Using half precision")
+        if args.to_static:
+            return
         self.enable_autocast_context_manager = True
         self.do_grad_scaling = True if self.args.fp16 else False
         self.amp_dtype = "float16" if self.args.fp16 else "bfloat16"
-        self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
+        self.scaler = dist.shard_scaler(paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss))
+        if self.args.fp16_opt_level == "O2":
+            paddle.amp.decorate(
+                models=model,
+                level=self.args.fp16_opt_level,
+                dtype=self.amp_dtype,
+                master_grad=self.args.amp_master_grad,
+                excluded_layers=QuantizationLinear,
+            )
 
     def _get_item_from_loss(self, loss):
         if isinstance(loss, paddle.Tensor):
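For orientation, here is a minimal sketch (not code from this PR) of how the pieces wired up in `_wrap_amp_model` are typically used together in a dynamic-graph training step. `model`, `optimizer`, `batch`, and `args` are hypothetical placeholders, and the actual trainer flow in PaddleNLP differs.

```python
import paddle
import paddle.distributed as dist


def setup_amp(model, args):
    # O2 casts parameters ahead of time; master_grad keeps float32 gradients.
    amp_dtype = "float16" if args.fp16 else "bfloat16"
    if args.fp16_opt_level == "O2":
        paddle.amp.decorate(
            models=model,
            level="O2",
            dtype=amp_dtype,
            master_grad=args.amp_master_grad,
        )
    # Grad scaling guards fp16 against underflow; shard_scaler (as in this PR)
    # adapts the scaler to gradients produced under auto parallel.
    scaler = dist.shard_scaler(paddle.amp.GradScaler(init_loss_scaling=args.scale_loss))
    return scaler, amp_dtype


def train_step(model, optimizer, scaler, amp_dtype, batch, args):
    with paddle.amp.auto_cast(True, level=args.fp16_opt_level, dtype=amp_dtype):
        loss = model(**batch)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscale gradients, skip the step on inf/nan
    scaler.update()                # adjust the loss-scaling factor
    optimizer.clear_grad()
    return loss
```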
33 changes: 16 additions & 17 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -153,7 +153,8 @@
         )
 
         attn_weights = attn_weights + attention_mask
-        attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)
+        with paddle.amp.auto_cast(False):
+            attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)
 
         attn_output = paddle.matmul(attn_weights, value_states)
         attn_output = attn_output.transpose([0, 2, 1, 3])
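The local `auto_cast(False)` guard means the attention softmax is evaluated in float32 even when the surrounding forward pass runs under AMP. A small, self-contained sketch of the same pattern (not part of this PR, assuming a GPU where float16 kernels are available):

```python
import paddle
import paddle.nn.functional as F

scores = paddle.randn([2, 8, 128, 128]).astype("float16")

with paddle.amp.auto_cast(True, dtype="float16"):
    # With autocast disabled locally, the softmax runs in float32 regardless of
    # AMP op lists; the result is cast back for the rest of the fp16 computation.
    with paddle.amp.auto_cast(False):
        probs = F.softmax(scores, axis=-1, dtype="float32").astype(scores.dtype)

print(probs.dtype)  # paddle.float16; the softmax itself ran in float32
```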
@@ -177,10 +178,7 @@
         if self.config.use_fused_rms_norm:
             return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
 
-        if paddle.in_dynamic_mode():
-            variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
-            hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
-        else:
+        with paddle.amp.auto_cast(False):
             variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
             hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
 
@@ -1050,20 +1048,21 @@
         masked_lm_labels = dist.reshard(masked_lm_labels, get_mesh(-1), [dist.Replicate(), dist.Replicate()])
 
         # Force entropy same kernel
-        if isinstance(prediction_scores, paddle.Tensor):
-            masked_lm_loss = self.loss_func(
-                prediction_scores.astype("float32")._use_gpudnn(False),
-                masked_lm_labels.unsqueeze(2),
-            )
-        else:
-            masked_lm_loss = self.loss_func(
-                prediction_scores.astype("float32"),
-                masked_lm_labels.unsqueeze(2),
-            )
+        with paddle.amp.auto_cast(False):
+            if isinstance(prediction_scores, paddle.Tensor):
+                masked_lm_loss = self.loss_func(
+                    prediction_scores.astype("float32")._use_gpudnn(False),
+                    masked_lm_labels.unsqueeze(2),
+                )
+            else:
+                masked_lm_loss = self.loss_func(
+                    prediction_scores.astype("float32"),
+                    masked_lm_labels.unsqueeze(2),
+                )
 
-        masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
-        loss = paddle.mean(masked_lm_loss)
+            masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
+            loss = paddle.mean(masked_lm_loss)
 
         return loss
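Taken together, the three hunks in this file apply one pattern: the attention softmax, the RMSNorm variance computation, and the cross-entropy loss (including the `masked_select` filtering and the final mean) are all wrapped in `paddle.amp.auto_cast(False)`, so these numerically sensitive reductions stay in float32 even when the rest of the forward pass runs in float16 or bfloat16 under AMP.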
69 changes: 69 additions & 0 deletions scripts/distribute/ci_case_auto.sh
@@ -48,6 +48,7 @@ function llama_case_list_auto() {
llama_dygraph_auto_bs8_fp32_DP2
llama_dygraph_auto_bs8_fp32_DP2-MP2
llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2
llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2

llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1
llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1
@@ -1440,6 +1441,74 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}

function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
echo "=========== $FUNCNAME run begin ==========="
export PYTHONPATH=$root_path/:$PYTHONPATH
export FLAGS_call_stack_level=3
export NVIDIA_TF32_OVERRIDE=0

task_name="llama_auto_bs8_fp16_dp2mp2pp2"
case_out_dir="output/$task_name"
case_log_dir="output/$task_name""_log"
rm -rf $case_out_dir
rm -rf $case_log_dir

python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir $case_log_dir run_pretrain_auto.py \
--model_type "llama" \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--input_dir "./data" \
--output_dir $case_out_dir \
--split 949,50,1 \
--max_seq_length 2048 \
--hidden_size 1024 \
--intermediate_size 3072 \
--num_hidden_layers 8 \
--num_attention_heads 32 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--use_flash_attention 0 \
--use_fused_rms_norm 0 \
--fp16 1 \
--fp16_opt_level "O2" \
--amp_master_grad 1 \
--scale_loss 1024 \
--pipeline_parallel_degree 2 \
--tensor_parallel_degree 2 \
--sharding_parallel_degree 1 \
--learning_rate 0.0001 \
--min_learning_rate 0.00001 \
--max_steps 10 \
--save_steps 5000 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--logging_steps 1 \
--dataloader_num_workers 1 \
--sharding "" \
--eval_steps 1000000 \
--disable_tqdm true \
--continue_training 0 \
--recompute 0 \
--do_train \
--do_eval \
--device "gpu" \
--data_impl "mmap" \
--enable_auto_parallel 1 \
--to_static 0 \
--max_grad_norm 1.0 \
>>${log_path}/$FUNCNAME 2>&1
loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
ips=-1
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.38341904
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
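The new CI case mirrors the existing fp32 DP2-MP2-PP2 dygraph case but enables `--fp16 1` with `--fp16_opt_level "O2"`, `--amp_master_grad 1`, and `--scale_loss 1024`, runs 10 steps with auto parallel in dynamic mode (`--to_static 0`), and compares the step-10 loss against the recorded baseline of 9.38341904; IPS and memory are not checked (both baselines are -1).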
############ case end ############

function check_result() {