Commit

update
GhostScreaming committed Jun 13, 2024
1 parent 28590c5 commit d1613ac
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions llm/run_pretrain.py
@@ -20,6 +20,7 @@
 from typing import List, Optional
 
 import paddle
+from paddle.io.reader import use_pinned_memory
 
 from paddlenlp.data.causal_dataset import (
     build_train_valid_test_datasets,
@@ -47,7 +48,6 @@
 from paddlenlp.utils.batch_sampler import DistributedBatchSampler
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.tools import get_env_device
-from paddle.io.reader import use_pinned_memory
 
 # Pretraining environment variables to support sharding stage1 overlap optimization.
 os.environ["USE_CASUAL_MASK"] = "True"
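
Since `USE_CASUAL_MASK` is exported here, before the model modules run, downstream code can read it back as a boolean. A minimal sketch of the assumed consumption pattern (the actual read site lives elsewhere in paddlenlp and is not part of this diff):

```python
import os

# Assumed reader for the flag set above; the variable name (including
# the "CASUAL" spelling) matches the identifier used in run_pretrain.py.
use_causal_mask = os.getenv("USE_CASUAL_MASK", "False") == "True"
```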
@@ -403,6 +403,10 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+    print("--888--" * 100)
+    print("training_args:", training_args)
+    print("--888--" * 100)
+
     if training_args.enable_linear_fused_grad_add:
         from fused_layers import mock_layers
 
@@ -499,6 +503,14 @@ def main():
             config.seq_length % config.context_parallel_degree == 0
         ), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}"
 
+    if training_args.sharding_parallel_config is not None:
+        # for stage1 overlap optimization
+        if (
+            "enable_stage1_allgather_overlap" in training_args.sharding_parallel_config
+            or "enable_stage1_broadcast_overlap" in training_args.sharding_parallel_config
+        ):
+            use_pinned_memory(False)
+
     if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
         try:
             from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
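
The added check treats `sharding_parallel_config` as a container of flag tokens and tests membership for the two stage1 overlap options. Below is a self-contained sketch of the same gate; the helper name `stage1_overlap_enabled` is hypothetical, the space-separated flag string is an assumed input format, and it assumes `use_pinned_memory(False)` must run before any DataLoader is built:

```python
from paddle.io.reader import use_pinned_memory


def stage1_overlap_enabled(sharding_parallel_config):
    # Mirrors the commit's check: either overlap flag disables pinned memory.
    if sharding_parallel_config is None:
        return False
    return (
        "enable_stage1_allgather_overlap" in sharding_parallel_config
        or "enable_stage1_broadcast_overlap" in sharding_parallel_config
    )


# Example: a space-separated flag string (assumed CLI format).
if stage1_overlap_enabled("enable_stage1_broadcast_overlap"):
    use_pinned_memory(False)  # flip the global toggle before building loaders
```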
Expand Down Expand Up @@ -635,5 +647,4 @@ def main():


if __name__ == "__main__":
use_pinned_memory(False)
main()
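
The moved call works because `use_pinned_memory` behaves as a process-wide toggle: calling it once early in `main()` (rather than unconditionally under `__main__`, as before this commit) is enough for every later reader. A sketch of the assumed shape of that toggle, illustrative only and not a verbatim copy of Paddle's source:

```python
# Assumed shape of paddle.io.reader.use_pinned_memory (illustrative only).
_use_pinned_memory = True  # module-level default


def use_pinned_memory(*args):
    global _use_pinned_memory
    if len(args) == 0:
        return _use_pinned_memory  # no argument: act as a getter
    _use_pinned_memory = args[0]  # one argument: set the process-wide flag
```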
