
[BUG] Training got stuck when pipeline is used. #1027

Open
sandyhouse opened this issue Aug 23, 2024 · 0 comments
sandyhouse commented Aug 23, 2024

Describe the bug
When I used pipeline parallelism to train a 7B GPT model, the program hung.

To Reproduce
I used the following script for training:

#!/bin/bash

# Runs the "7B" parameter model

export CUDA_DEVICE_MAX_CONNECTIONS=1

export OMP_NUM_THREADS=1

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NUM_NODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))

CHECKPOINT_PATH=$1 #<Specify path>
TENSORBOARD_LOGS_PATH=$2 #<Specify path>
#VOCAB_FILE=$3 #<Specify path to file>/gpt2-vocab.json
#MERGE_FILE=$4 #<Specify path to file>/gpt2-merges.txt
DATA_PATH=$3 #<Specify path and file prefix>_text_document
TOKENIZER_MODEL_PATH=$4 #<Specify path to tokenizer model>
DATA_CACHE_PATH=$5 #<Specify path> (referenced by --data-cache-path below)

DISTRIBUTED_ARGS=(
    --nproc_per_node $GPUS_PER_NODE
    --nnodes $NUM_NODES
    --master_addr $MASTER_ADDR
    --master_port $MASTER_PORT
)

GPT_MODEL_ARGS=(
    --use-mcore-models
    --transformer-impl "transformer_engine"
    --num-layers 32
    --hidden-size 4096
    --num-attention-heads 32
    --seq-length 4096
    --max-position-embeddings 4096
)

TRAINING_ARGS=(
    --micro-batch-size 1
    --global-batch-size 128
    --train-iters 275
    --weight-decay 0.1
    --adam-beta1 0.9
    --adam-beta2 0.95
    --init-method-std 0.006
    --clip-grad 1.0
    --fp16
    --lr 3.0e-4
    --lr-decay-style cosine
    --min-lr 3.0e-5
    --lr-warmup-fraction .001
    --lr-decay-iters 275
)

MODEL_PARALLEL_ARGS=(
        --tensor-model-parallel-size 1
        --pipeline-model-parallel-size 4
        --use-distributed-optimizer
        --sequence-parallel
)

DATA_ARGS=(
    --data-path $DATA_PATH
    --split 100,0,0
    --tokenizer-type Llama2Tokenizer
    --tokenizer-model ${TOKENIZER_MODEL_PATH}
    --data-cache-path $DATA_CACHE_PATH
)

EVAL_AND_LOGGING_ARGS=(
    --log-interval 1
    --save-interval 10000
    --eval-interval 1000
    --save $CHECKPOINT_PATH
    --load $CHECKPOINT_PATH
    --eval-iters 0
    --tensorboard-dir $TENSORBOARD_LOGS_PATH
    --log-throughput
)

torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
    ${GPT_MODEL_ARGS[@]} \
    ${TRAINING_ARGS[@]} \
    ${MODEL_PARALLEL_ARGS[@]} \
    ${DATA_ARGS[@]} \
    ${EVAL_AND_LOGGING_ARGS[@]}

I found that this issue occurs because the device_id parameter is passed to torch.distributed.init_process_group. With device_id set, processes that are not members of a new_group call ncclCommSplit on the parent communicator to create the child communicator. For processes that are members of the new_group, however, the child communicator is not created until the first communication operation on that group. As a result, the non-member processes keep waiting on the parent communicator for the member processes to join the split. In other words, this violates the convention that all ranks must call ncclCommSplit on the original communicator.

# in megatron/training/initialize.py: 431L (_initialize_distributed function)
if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"):
    init_process_group_kwargs['device_id'] = device_id

A detailed description can be found at pytorch/pytorch#134314.
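
For reference, here is a minimal standalone sketch of the hang pattern described above. It is not taken from Megatron-LM; the script name and the 4-GPU torchrun launch are assumptions, and whether it actually deadlocks depends on the PyTorch/NCCL versions involved. It only illustrates the shape of the problem: device_id is passed to init_process_group, a subgroup is created, and only the member ranks ever issue a collective on it.

# repro_split_hang.py (hypothetical name); launch with:
#   torchrun --nproc_per_node=4 repro_split_hang.py
import os
import torch
import torch.distributed as dist

def main():
    rank = int(os.environ["RANK"])
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    torch.cuda.set_device(device)

    # Passing device_id (PyTorch >= 2.3) binds the default process group to this
    # device, which enables subgroup creation via ncclCommSplit as described above.
    dist.init_process_group(backend="nccl", device_id=device)

    # All ranks must call new_group, but only ranks 0 and 1 are members.
    # Per the description above, non-member ranks (2, 3) enter the split on the
    # parent communicator here, while member ranks defer their child communicator
    # until the first collective on the subgroup -- hence the hang.
    subgroup = dist.new_group(ranks=[0, 1])

    if rank in (0, 1):
        # Member ranks only trigger their side of the split when they first
        # communicate on the subgroup.
        t = torch.ones(1, device=device)
        dist.all_reduce(t, group=subgroup)

    dist.barrier()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()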

Environment (please complete the following information):

  • Megatron-LM commit ID: 85bd1f9
  • PyTorch version: 2.4.0a0+f70bd71a48.nv24.6 (NGC 24.06 container)
  • CUDA version: 12.5
  • NCCL version: 2.21.5+cuda12.5

Proposed fix
Until this issue is fixed in PyTorch, I think we should avoid passing the device_id parameter to torch.distributed.init_process_group, thereby falling back to the original lazy-initialization behavior.
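
A sketch of that workaround is shown below. It is not an official patch: it assumes the init_process_group_kwargs dict quoted above is unpacked directly into torch.distributed.init_process_group, and the exact surrounding code in _initialize_distributed may differ. The idea is simply to drop or comment out the two quoted lines so device_id is never set.

# megatron/training/initialize.py, _initialize_distributed (workaround sketch):
# leave init_process_group_kwargs without device_id so that subgroup
# communicators are again created lazily on first use.

# if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"):
#     init_process_group_kwargs['device_id'] = device_id

torch.distributed.init_process_group(**init_process_group_kwargs)

Once the upstream fix lands, the device_id kwarg could be restored behind a version check to regain eager initialization.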
