Skip to content

Commit

Permalink
Revert "Use gloo as part of DeviceGPU's process group backend (#3509)" (#3523)
Browse files Browse the repository at this point in the history

This reverts commit cccc8a7.

reverting
  • Loading branch information
snarayan21 authored Aug 6, 2024
1 parent b21c2a9 commit 3aa266f
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 10 deletions.
5 changes: 0 additions & 5 deletions composer/devices/device_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
import torch.backends.cudnn
import torch.cuda
import torch.cuda.amp
import torch.distributed as torch_dist
import torch.utils.data
from packaging import version

from composer.devices.device import Device
from composer.utils import dist
Expand Down Expand Up @@ -44,9 +42,6 @@ def __init__(
):
if not torch.cuda.is_available():
raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.')
if torch_dist.is_gloo_available() and version.parse(torch.__version__) >= version.parse('2.3.0'):
# Composer checkpoint load / save from before torch 2.3.0 is not compatible with gloo + nccl backends.
DeviceGPU.dist_backend = 'cuda:nccl,cpu:gloo'
if device_id is None:
device_id = dist.get_local_rank()
self._device = torch.device(f'cuda:{device_id}')
Expand Down
6 changes: 1 addition & 5 deletions tests/checkpoint/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import pytest
import torch
import torch.distributed as torch_dist
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

Expand Down Expand Up @@ -447,10 +446,7 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz
assert 'model_name' in metadata_sd

assert 'dist_backend' in metadata_sd
if torch_dist.is_gloo_available() and version.parse(torch.__version__) >= version.parse('2.3.0'):
assert metadata_sd['dist_backend'] == 'cuda:nccl,cpu:gloo'
else:
assert metadata_sd['dist_backend'] == 'nccl'
assert metadata_sd['dist_backend'] == 'nccl'


@pytest.mark.filterwarnings('ignore:SWA has')
Expand Down

0 comments on commit 3aa266f

Please sign in to comment.