fix backend check (#2670)
* fix backend check

* reformat backend check

* Update src/accelerate/state.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* Update src/accelerate/state.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* raise value error if backend mismatch

* Update src/accelerate/state.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
jiqing-feng and muellerzr committed Apr 17, 2024
1 parent 39e0a8e commit fa0bd40
Showing 1 changed file with 24 additions and 17 deletions.
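The new guard in `__init__` targets callers that ask for a specific `torch.distributed` backend. Such a request typically reaches `PartialState` through an `InitProcessGroupKwargs` handler; the sketch below is illustrative rather than definitive, since the outcome depends on the installed torch build, the launcher, and the machine's accelerators:

```python
from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs

# The handler's `backend` field is forwarded to PartialState as the
# "backend" kwarg that this commit now validates.
handler = InitProcessGroupKwargs(backend="nccl")

# Under a multi-process CPU-only launch, _prepare_backend cannot honor
# "nccl" and resolves to "gloo". Before this commit the request was
# silently replaced; after it, PartialState raises a ValueError that
# names the backend which is actually usable.
accelerator = Accelerator(cpu=True, kwargs_handlers=[handler])
```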
41 changes: 24 additions & 17 deletions src/accelerate/state.py
@@ -179,7 +179,10 @@ def __init__(self, cpu: bool = False, **kwargs):
                 )
 
             # Sets up self.backend + imports
-            backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, kwargs.pop("backend", None))
+            original_backend = kwargs.pop("backend", None)
+            backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
+            if original_backend is not None and backend != original_backend:
+                raise ValueError(f"Your assigned backend {original_backend} is not available, please use {backend}")
             self.backend = backend
             self.distributed_type = distributed_type
             use_deepspeed = False
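Taken in isolation, the new check amounts to: record the caller's request, let `_prepare_backend` decide what is actually usable, and fail loudly if the two disagree. A self-contained sketch of that pattern with hypothetical names (`resolve_backend` is not part of the library):

```python
def resolve_backend(requested: str | None, usable: list[str]) -> str:
    """Honor an explicit backend request only when it is actually usable."""
    if requested is not None and requested in usable:
        return requested
    resolved = usable[0]  # stand-in for _prepare_backend's preference order
    if requested is not None:
        # Mirrors the new guard: an unsatisfiable request is an error,
        # not a silent substitution.
        raise ValueError(f"Your assigned backend {requested} is not available, please use {resolved}")
    return resolved


print(resolve_backend(None, ["gloo"]))    # -> gloo
print(resolve_backend("gloo", ["gloo"]))  # -> gloo
resolve_backend("nccl", ["gloo"])         # -> ValueError: ... please use gloo
```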
@@ -718,41 +721,45 @@ def _prepare_backend(
         elif is_torch_xla_available():
             backend = "xla"
             distributed_type = DistributedType.XLA
-        elif int(os.environ.get("LOCAL_RANK", -1)) != -1:
-            if not cpu:
-                if is_mlu_available():
-                    backend = "cncl"
-                    distributed_type = DistributedType.MULTI_MLU
-                elif torch.cuda.is_available():
-                    if backend is None:
-                        backend = "nccl"
-                    distributed_type = DistributedType.MULTI_GPU
-                elif is_npu_available():
-                    backend = "hccl"
-                    distributed_type = DistributedType.MULTI_NPU
-        if backend is None and (
+        elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
+            if is_mlu_available():
+                backend = "cncl"
+                distributed_type = DistributedType.MULTI_MLU
+            elif torch.cuda.is_available():
+                if backend is None:
+                    backend = "nccl"
+                distributed_type = DistributedType.MULTI_GPU
+            elif is_npu_available():
+                backend = "hccl"
+                distributed_type = DistributedType.MULTI_NPU
+
+        if distributed_type is None and (
             int(os.environ.get("LOCAL_RANK", -1)) != -1
             or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1
         ):
             if not cpu and is_xpu_available():
                 distributed_type = DistributedType.MULTI_XPU
             else:
                 distributed_type = DistributedType.MULTI_CPU
-            if is_ccl_available() and (
-                get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU
+
+            if (
+                backend in (None, "ccl")
+                and is_ccl_available()
+                and (get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU)
             ):
                 if get_ccl_version() >= "1.12":
                     import oneccl_bindings_for_pytorch  # noqa: F401
                 else:
                     import torch_ccl  # noqa: F401
 
                 backend = "ccl"
-            elif torch.distributed.is_mpi_available():
+            elif backend in (None, "mpi") and torch.distributed.is_mpi_available():
                 backend = "mpi"
             else:
                 backend = "gloo"
+        if distributed_type is None:
+            distributed_type = DistributedType.NO
 
         return backend, distributed_type
 
     def set_device(self):
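The reshuffled CPU branch is where a request now actually matters: "ccl" and "mpi" are only auto-selected when the caller asked for them or left the choice open (`backend in (None, ...)`), everything else falls through to "gloo", and a run that never matched a distributed branch now defaults cleanly to `DistributedType.NO`. A simplified model of that decision with hypothetical names (the real code also consults environment variables and the oneCCL imports):

```python
def pick_cpu_backend(requested: str | None, ccl_ok: bool, mpi_ok: bool) -> str:
    """Simplified mirror of the CPU-side backend choice in _prepare_backend."""
    if requested in (None, "ccl") and ccl_ok:
        return "ccl"
    if requested in (None, "mpi") and mpi_ok:
        return "mpi"
    return "gloo"


# With oneCCL present, an explicit "mpi" request is no longer overridden:
assert pick_cpu_backend("mpi", ccl_ok=True, mpi_ok=True) == "mpi"
# An unsatisfiable request resolves to "gloo", which the new __init__
# guard then reports as a ValueError:
assert pick_cpu_backend("nccl", ccl_ok=False, mpi_ok=False) == "gloo"
```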
