Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fit for llama3 for auto_parallel #8395

Merged
merged 2 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions llm/llama/auto_parallel/run_llama3.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# just for debug

set -x
unset CUDA_VISIBLE_DEVICES

task_name="llama3_dp2pp4sd2"
rm -rf output/$task_name/
rm -rf "output/$task_name""_log"

export SOT_LOG_LEVEL=4
export PYTHONPATH=../../../:$PYTHONPATH

#ulimit -c unlimited
#export GLOG_v=10

# export FLAGS_call_stack_level=3
# export FLAGS_use_cuda_managed_memory=true

# export FLAGS_embedding_deterministic=1
# export FLAGS_cudnn_deterministic=1
# export NVIDIA_TF32_OVERRIDE=0

python -u -m paddle.distributed.launch \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir "output/$task_name""_log" \
./run_pretrain_auto.py \
--model_name_or_path "meta-llama/Meta-Llama-3-8B-Instruct" \
--tokenizer_name_or_path "meta-llama/Meta-Llama-3-8B-Instruct" \
--input_dir "./data" \
--output_dir "./output" \
--split 949,50,1 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--learning_rate 3e-05 \
--min_learning_rate 3e-06 \
--max_steps 30 \
--logging_steps 10 \
--eval_steps 1000 \
--save_steps 50000 \
--continue_training 0 \
--do_train true \
--do_eval false \
--do_predict false \
--disable_tqdm true \
--skip_profile_timer true \
--save_total_limit 2 \
--device gpu \
--disable_tqdm true \
--dataloader_num_workers 1 \
--distributed_dataloader 0 \
--enable_auto_parallel 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 4 \
--per_device_eval_batch_size 1 \
--recompute false \
--recompute_use_reentrant true \
--recompute_granularity full \
--pp_recompute_interval 0 \
--bf16 true \
--fp16_opt_level "O2" \
--amp_master_grad true \
--fuse_attention_ffn false \
--fuse_attention_qkv false \
--fused_linear_param_grad_add 1 \
--fuse_sequence_parallel_allreduce false \
--use_flash_attention true \
--use_fused_rope true \
--use_fused_rms_norm true \
--max_seq_length 4096 \
--sep_parallel_degree 1 \
--sequence_parallel false \
--pipeline_parallel_degree 2 \
--sharding_parallel_degree 2 \
--tensor_parallel_degree 1 \
--virtual_pp_degree 4 \
--pipeline_schedule_mode "VPP" \
--sharding "stage2" \
--pipeline_parallel_config "enable_send_recv_overlap" \
--data_parallel_config "enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate" \
--sharding_parallel_config "enable_stage2_overlap" \
--tensor_parallel_config "enable_mp_async_allreduce" \
--to_static 1 \
--amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
--amp_custom_white_list "lookup_table" "lookup_table_v2" \
--skip_memory_metrics 0

4 changes: 4 additions & 0 deletions paddlenlp/transformers/llama/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,24 +367,28 @@ def _init_rope(self):
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rope_theta,
)
elif self.config.rope_scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=self.config.rope_scaling_factor,
base=self.config.rope_theta,
)
elif self.config.rope_scaling_type == "ntk":
self.rotary_emb = LlamaNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=self.config.rope_scaling_factor,
base=self.config.rope_theta,
)
elif self.config.rope_scaling_type == "dynamic_ntk":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=self.config.rope_scaling_factor,
base=self.config.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {self.config.rope_scaling_type}")
Expand Down
Loading