
Loss of 175B Megatron-LM doesn't converge with memory-efficient-attention when dropout.p=0.1 and attn-bias=LowerTriangularMask #724

Closed
ZhangDY-6483 opened this issue Apr 12, 2023 · 6 comments

@ZhangDY-6483 commented Apr 12, 2023

🐛 Bug

The loss doesn't converge, as shown in the attached figure:

![meg_dp1_dp1](https://user-images.githubusercontent.com/64682152/231338310-0ed46566-ef84-4fea-b6fc-5326e7a7c20c.jpeg)

To Reproduce

Steps to reproduce the behavior:

  1. Run the 175B-parameter model on Megatron-LM with its MultiHeadAttn replaced by memory-efficient-attention (xformers v0.0.18); a minimal sketch of the replacement is shown after this list.
  2. Measure the loss: the results show the loss doesn't converge.
  3. To reproduce, replace Megatron-LM/megatron/model/transformer.py with the attached file.
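For context, a minimal sketch of what the replacement looks like. This is not the attached transformer.py; the tensor layout and the helper name below are assumptions for illustration only.

```python
# Hypothetical sketch of the core-attn replacement, assuming inputs are already
# in xformers' [batch, seq_len, num_heads, head_dim] layout; Megatron's own
# [sq, b, np, hn] layout would need a reshape first.
import torch
import xformers.ops as xops

def core_attention_xformers(q, k, v, dropout_p=0.1):
    # Causal mask and dropout, matching the settings in the issue title.
    return xops.memory_efficient_attention(
        q, k, v,
        attn_bias=xops.LowerTriangularMask(),
        p=dropout_p,
    )

# Toy usage
q = torch.randn(2, 128, 16, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = core_attention_xformers(q, k, v)  # same shape as q
```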

Expected behavior

The loss should follow the same trend before and after the replacement.

Attached file: transformers.txt

Environment

Output of the PyTorch environment collection script (collect_env.py):

PyTorch version: 1.13.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (GCC) 8.2.0
Clang version: 3.8.0 (tags/RELEASE_380/final)
CMake version: version 3.26.1
Libc version: glibc-2.26

Python version: 3.7.13 (default, Apr 24 2022, 01:04:09) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.14.0_1-0-0-44-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 470.82.01
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.4.1
/usr/lib/libcudnn_adv_infer.so.8.4.1
/usr/lib/libcudnn_adv_train.so.8.4.1
/usr/lib/libcudnn_cnn_infer.so.8.4.1
/usr/lib/libcudnn_cnn_train.so.8.4.1
/usr/lib/libcudnn_ops_infer.so.8.4.1
/usr/lib/libcudnn_ops_train.so.8.4.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.21.6
[pip3] torch==1.13.1
[pip3] torchaudio==0.13.1+cu117
[pip3] torchvision==0.14.1+cu117
[conda] Could not collect

@danthe3rd
Contributor

Hi @ZhangDY-6483
I managed to reproduce the issue on my end and am working on it. It's specific to the CUTLASS implementation.
As a workaround for now, you can force usage of Flash-Attention [1] by adding an extra op parameter to xformers.ops.memory_efficient_attention:

xformers.ops.memory_efficient_attention(..., op=xformers.ops.fmha.MemoryEfficientAttentionFlashAttentionOp)

[1] This only works because you are on an A100, use f16/bf16, and don't use a torch.Tensor as your attention bias.
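For completeness, a slightly fuller sketch of that call; the shapes are illustrative and bf16 is assumed, per note [1].

```python
# Hedged sketch of the workaround: force the Flash-Attention backend.
# Per note [1]: requires A100, f16/bf16 inputs, and a non-torch.Tensor
# attention bias such as LowerTriangularMask. Shapes are illustrative.
import torch
import xformers.ops as xops

q = torch.randn(2, 128, 16, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.LowerTriangularMask(),
    p=0.1,
    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp,  # force Flash-Attention
)
```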

@ZhangDY-6483
Author

> Hi @ZhangDY-6483 I managed to reproduce the issue on my end and am working on it. It's specific to the CUTLASS implementation. As a workaround for now, you can force usage of Flash-Attention [1] by adding an extra op parameter to xformers.ops.memory_efficient_attention:
>
> xformers.ops.memory_efficient_attention(..., op=xformers.ops.fmha.MemoryEfficientAttentionFlashAttentionOp)
>
> [1] This only works because you are on an A100, use f16/bf16, and don't use a torch.Tensor as your attention bias.

Thanks. Can I use memory-efficient attention with fp32 and pipeline parallelism (no data/model parallelism)?

@danthe3rd
Contributor

Flash will only work with f16/bf16, unfortunately. I'm not exactly sure what you mean by pipeline parallelism, but it should not interfere with the attention part, I think.

@ZhangDY-6483
Author

Sorry, I mean pipeline parallelism, a strategy for training large models on multiple GPUs (Ref 1: PIPELINE PARALLELISM; Ref 2: https://arxiv.org/abs/1811.06965).
In my case fp32 is needed, which is why I must use xformers.ops.memory_efficient_attention. I want to run Megatron-LM with "core-attn" replaced by xformers.ops.memory_efficient_attention in MultiHeadAttention (https://github.com/NVIDIA/Megatron-LM.git) and with "pipeline-model-parallel-size" set to larger than one. Have you ever tried that?

@danthe3rd
Contributor

I've got a fix for the dropout issue. I need to do some more testing, but it should be available next week, hopefully :)

@danthe3rd
Contributor

It should be fixed as of 70161e5, and will be included in the next release (0.0.19). In the meantime, you can also use a development build >=0.0.19.dev516
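A quick way to check whether an installed build already includes the fix (assuming the standard xformers.__version__ attribute):

```python
# Verify the installed xformers version is >= 0.0.19,
# or a development build >= 0.0.19.dev516.
import xformers
print(xformers.__version__)
```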
