Error pre-training BERT #1043

Open

fabiancpl opened this issue Jul 25, 2024 · 1 comment

Hi guys,

I am following the Megatron-LM example to pre-train a BERT model, but I'm getting this error:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/Megatron-LM/pretrain_bert.py", line 193, in <module>
[rank0]:     pretrain(train_valid_test_datasets_provider, model_provider,
[rank0]:   File "/root/Megatron-LM/megatron/training/training.py", line 274, in pretrain
[rank0]:     iteration, num_floating_point_operations_so_far = train(
[rank0]:                                                       ^^^^^^
[rank0]:   File "/root/Megatron-LM/megatron/training/training.py", line 1027, in train
[rank0]:     train_step(forward_step_func,
[rank0]:   File "/root/Megatron-LM/megatron/training/training.py", line 550, in train_step
[rank0]:     losses_reduced = forward_backward_func(
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 381, in forward_backward_no_pipelining
[rank0]:     output_tensor, num_tokens = forward_step(
[rank0]:                                 ^^^^^^^^^^^^^
[rank0]:   File "/root/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 206, in forward_step
[rank0]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank0]:                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/Megatron-LM/pretrain_bert.py", line 139, in forward_step
[rank0]:     output_tensor = model(tokens, padding_mask,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py", line 180, in forward
[rank0]:     return self.module(*inputs, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/Megatron-LM/megatron/legacy/model/module.py", line 190, in forward
[rank0]:     outputs = self.module(*inputs, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/Megatron-LM/megatron/core/models/bert/bert_model.py", line 237, in forward
[rank0]:     hidden_states = self.encoder(
[rank0]:                     ^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/Megatron-LM/megatron/core/transformer/transformer_block.py", line 383, in forward
[rank0]:     hidden_states, context = layer(
[rank0]:                              ^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 178, in forward
[rank0]:     attention_output_with_bias = self.self_attention(
[rank0]:                                  ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/Megatron-LM/megatron/core/transformer/attention.py", line 315, in forward
[rank0]:     core_attn_out = self.core_attention(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/Megatron-LM/megatron/core/transformer/custom_layers/transformer_engine.py", line 514, in forward
[rank0]:     core_attn_out = super().forward(
[rank0]:                     ^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/transformer_engine/pytorch/attention.py", line 5267, in forward
[rank0]:     return self.fused_attention(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/transformer_engine/pytorch/attention.py", line 4157, in forward
[rank0]:     cu_seqlens_q = get_cu_seqlens(attention_mask)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/transformer_engine/pytorch/attention.py", line 247, in get_cu_seqlens
[rank0]:     cu_seqlens = torch.cat((zero, cu_seqlens))
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Tensors must have same number of dimensions: got 1 and 2
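
In case it helps with triage: the failure is in the final torch.cat inside get_cu_seqlens, and it only occurs when the cumulative-sequence-length tensor built from the attention mask comes out 2-D instead of 1-D. My guess is that the padding mask reaches TE's fused attention with an extra dimension, but I haven't confirmed that. Here is a minimal standalone snippet (illustrative shapes only, not the actual get_cu_seqlens internals) that reproduces the same exception:

import torch

# Illustrative only: a 2-D cu_seqlens tensor (one entry per sample with a
# leftover singleton dim) cannot be concatenated with the 1-D leading zero.
cu_seqlens = torch.tensor([[512], [512], [512], [512]], dtype=torch.int32)  # 2-D instead of 1-D
zero = torch.zeros(1, dtype=torch.int32)
torch.cat((zero, cu_seqlens))
# RuntimeError: Tensors must have same number of dimensions: got 1 and 2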

I'm using transformer_engine==1.8.0+3ec998e with Megatron-LM core_v0.7.0. My pre-training script looks like this:

#!/bin/bash

export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FLASH_ATTN=0

CHECKPOINT_PATH="./checkpoints"
VOCAB_FILE="./vocabs/bert-base-cased-vocab.txt"
DATA_PATH="my-bert_text_sentence"

BERT_ARGS="
    --num-layers 24 \
    --hidden-size 1024 \
    --num-attention-heads 16 \
    --seq-length 512 \
    --max-position-embeddings 512 \
    --micro-batch-size 4 \
    --global-batch-size 8 \
    --lr 0.0001 \
    --train-iters 2000000 \
    --lr-decay-iters 990000 \
    --lr-decay-style linear \
    --min-lr 0.00001 \
    --weight-decay 1e-2 \
    --lr-warmup-fraction .01 \
    --clip-grad 1.0 \
    --fp16 \
    --use-mcore-models
"

DATA_ARGS="
    --data-path $DATA_PATH \
    --vocab-file $VOCAB_FILE \
    --split 949,50,1
"

OUTPUT_ARGS="
    --log-interval 100 \
    --save-interval 10000 \
    --eval-interval 1000 \
    --eval-iters 10
"

torchrun pretrain_bert.py \
    $BERT_ARGS \
    $DATA_ARGS \
    $OUTPUT_ARGS \
    --save $CHECKPOINT_PATH \
    --load $CHECKPOINT_PATH

I'm also interested in using the BERT cased checkpoint instead of pre-training from scratch.

Thanks in advance.

ksivaman self-assigned this Jul 25, 2024

sbhavani (Contributor) commented Sep 5, 2024

@fabiancpl Sorry that the Megatron-LM example is not working for you; we'll look into it.

If you are interested in fine-tuning an existing checkpoint, TE has an integration with Hugging Face Accelerate, and there is an example for "bert-base-cased" here: https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8.
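
If a self-contained starting point is useful before digging into that example, something along these lines should work (a rough sketch, not the benchmark script itself; the dataset, sequence length, and hyperparameters below are just placeholders):

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from datasets import load_dataset
from transformers import AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling

# Accelerator handles device placement and mixed precision; "fp8" should also
# be accepted here when Transformer Engine is installed.
accelerator = Accelerator(mixed_precision="fp16")

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")

# Placeholder corpus; swap in your own data.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
dataset = dataset.map(
    lambda ex: tokenizer(ex["text"], truncation=True, max_length=512),
    batched=True,
    remove_columns=["text"],
)
collator = DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.15)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collator)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

model.train()
for batch in loader:
    loss = model(**batch).loss       # masked-LM loss from the collator's labels
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()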
