diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py
index 7ac4873821c8..792e7d1014c1 100644
--- a/llm/run_pretrain.py
+++ b/llm/run_pretrain.py
@@ -20,6 +20,7 @@
 from typing import List, Optional
 
 import paddle
+from paddle.io.reader import use_pinned_memory
 
 from paddlenlp.data.causal_dataset import (
     build_train_valid_test_datasets,
@@ -47,7 +48,6 @@
 from paddlenlp.utils.batch_sampler import DistributedBatchSampler
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.tools import get_env_device
-from paddle.io.reader import use_pinned_memory
 
 # Pretaining Environment Variables to support sharding stage1 overlap optimization.
 os.environ["USE_CASUAL_MASK"] = "True"
@@ -403,6 +403,10 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+    print("--888--" * 100)
+    print("training_args:", training_args)
+    print("--888--" * 100)
+
     if training_args.enable_linear_fused_grad_add:
         from fused_layers import mock_layers
 
@@ -499,6 +503,14 @@ def main():
             config.seq_length % config.context_parallel_degree == 0
         ), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}"
 
+    if training_args.sharding_parallel_config is not None:
+        # for stage1 overlap optimization
+        if (
+            "enable_stage1_allgather_overlap" in training_args.sharding_parallel_config
+            or "enable_stage1_broadcast_overlap" in training_args.sharding_parallel_config
+        ):
+            use_pinned_memory(False)
+
     if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
         try:
             from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
@@ -635,5 +647,4 @@ def main():
 
 
 if __name__ == "__main__":
-    use_pinned_memory(False)
     main()
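
For context, a minimal standalone sketch (not part of the patch) of the gating logic the diff adds inside main(): DataLoader pinned memory is disabled only when a stage1 overlap flag appears in sharding_parallel_config, instead of unconditionally at module entry as before. The helper name maybe_disable_pinned_memory is hypothetical; only use_pinned_memory comes from the patch, and sharding_parallel_config is assumed to be a string or list of flag names, matching the `in` checks in the hunk above.

    # Illustrative sketch, not part of the patch. Mirrors the check added in main().
    from paddle.io.reader import use_pinned_memory  # same import the patch moves to the top


    def maybe_disable_pinned_memory(sharding_parallel_config):  # hypothetical helper
        """Disable DataLoader pinned memory when a stage1 overlap flag is enabled."""
        if sharding_parallel_config is None:
            return
        overlap_flags = (
            "enable_stage1_allgather_overlap",
            "enable_stage1_broadcast_overlap",
        )
        # `in` works for both a space-separated string and a list of flag names.
        if any(flag in sharding_parallel_config for flag in overlap_flags):
            use_pinned_memory(False)

Called as maybe_disable_pinned_memory(training_args.sharding_parallel_config), this reproduces the effect of the removed module-level use_pinned_memory(False) call, but only for configurations that request stage1 allgather/broadcast overlap.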