Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace 1F1B with ZB-H1 #93

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Replace 1F1B with ZB-H1 #93

wants to merge 1 commit into from

Conversation

QPHutu
Copy link

@QPHutu QPHutu commented Jan 22, 2024

The change is a quick implementation to replace 1F1B with ZB-H1 proposed in Zero Bubble Pipeline Parallelism, which reduces the bubbles in pipeline parallelism.

@QPHutu
Copy link
Author

QPHutu commented Jan 22, 2024

The paper been accepted by ICLR 2024.

The key idea is to split the backward computation into two parts, one that computes gradient for the input and another that
computes for the parameters. By rescheduling the parameters' gradient computation, we can have get a better efficiency without scrificing anything.

image image

@Dylancer1998
Copy link

May I ask what led you to commit to this repository over the original one? Just curious about your thoughts! @QPHutu

@QPHutu
Copy link
Author

QPHutu commented Jan 30, 2024

Thanks for the reply. There are 2 main reasons.

  1. We have one internal team using this repo to train LLM. So to better support their training, we decide to merge this commit to upstream.
  2. We also have plans to merge our new scheduling methods to the original Megatron, not only ZB-H1, but also all other schedulers. However, the whole code changes are quite complicated, so both us and Nvidia want to be careful about that. To make it simpler, we want to push ZB-H1 to the community first.

@martinjaggi
Copy link
Contributor

thanks for the PR!

for merging we'd like to understand the impact a bit better. did you verify how model parallel training of the current models supported here (such as llama2) is impacted by your change? (in terms of speed, stability and also verify model behavior is unchanged?)

indeed could be nice to also hear the feedback from the Nvidia/Megatron-LM team if you get a chance

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants