Describe the bug
When I used pipeline parallelism to train a 7B GPT model, the program hung.
To Reproduce
I used the following script for training:
I found that this issue occurs because the device_id parameter is set when calling torch.distributed.init_process_group. With device_id set, processes that are not members of a new_group eagerly call ncclCommSplit on the parent communicator to create the child communicator, while processes that are members of the new_group do not create their communicator until the first communication operation on that group. The non-member processes therefore keep waiting on the parent communicator for the member processes to join the split. In other words, this violates the convention that all ranks must call ncclCommSplit on the original communicator. The device_id parameter is set here:

# in megatron/training/initialize.py, around line 431 (_initialize_distributed)
if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"):
    init_process_group_kwargs['device_id'] = device_id
A detailed description can be found at pytorch/pytorch#134314.
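For concreteness, below is a minimal standalone sketch of the failure pattern as I understand it, not Megatron-LM code or the actual training script; it assumes four GPUs launched with torchrun --nproc_per_node=4.

# Minimal sketch of the hang pattern described above (an illustration, not
# Megatron-LM code). Assumes: torchrun --nproc_per_node=4 repro.py
import os

import torch
import torch.distributed as dist


def main():
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # Passing device_id (supported since PyTorch 2.3) binds the default NCCL
    # communicator eagerly, which lets new_group() try to derive subgroup
    # communicators via ncclCommSplit.
    dist.init_process_group(backend="nccl", device_id=device)

    # Ranks 2 and 3 are not members of this group. With device_id set they
    # call ncclCommSplit on the parent communicator inside new_group() and
    # block there, because ranks 0 and 1 only create the child communicator
    # lazily on their first collective over the group.
    subgroup = dist.new_group(ranks=[0, 1])

    # Ranks 0 and 1 reach this barrier while ranks 2 and 3 are still stuck
    # inside new_group() above, so the whole job hangs.
    dist.barrier()

    if rank in (0, 1):
        t = torch.ones(1, device=device)
        dist.all_reduce(t, group=subgroup)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()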
Environment (please complete the following information):
Proposed fix
Until this issue is fixed in PyTorch, I think we should avoid passing the device_id parameter to torch.distributed.init_process_group, thereby falling back to the original lazy-initialization behavior.
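One possible shape for this workaround, sketched as a self-contained snippet rather than a patch to _initialize_distributed: the variable names mirror the code quoted in the description, and the MEGATRON_USE_DEVICE_ID environment variable is only an illustration of an opt-in gate, not an existing Megatron-LM option.

# Workaround sketch, assuming a torchrun-style launch. Variable names follow
# the snippet quoted above; MEGATRON_USE_DEVICE_ID is a hypothetical opt-in
# flag used here for illustration only.
import os

import packaging.version
import torch
import torch.distributed as dist

local_rank = int(os.getenv("LOCAL_RANK", "0"))
device_id = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device_id)

init_process_group_kwargs = {
    "backend": "nccl",
    "world_size": int(os.getenv("WORLD_SIZE", "1")),
    "rank": int(os.getenv("RANK", "0")),
}

# Previously device_id was always set on torch >= 2.3.0, which enables the
# eager ncclCommSplit path in new_group(). Keep the old lazy initialization
# by default and only pass device_id when explicitly requested.
if (
    packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0")
    and os.getenv("MEGATRON_USE_DEVICE_ID", "0") == "1"
):
    init_process_group_kwargs["device_id"] = device_id

dist.init_process_group(**init_process_group_kwargs)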