fix backend check #2670

Merged Apr 17, 2024 · 6 commits
Changes shown from 2 commits
src/accelerate/state.py (41 changed lines: 24 additions, 17 deletions)
@@ -179,7 +179,10 @@ def __init__(self, cpu: bool = False, **kwargs):
                 )

             # Sets up self.backend + imports
-            backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, kwargs.pop("backend", None))
+            original_backend = kwargs.pop("backend", None)
+            backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
+            if original_backend and backend != original_backend:
+                logger.warning(f"The assigned backend is {original_backend}, but the real backend is {backend}")
             self.backend = backend
             self.distributed_type = distributed_type
             use_deepspeed = False
Review thread on the logger.warning line above:

Collaborator: I don't think it makes sense to do a warning here. This should be a raise if this is truly a situation we can get into.

Author (jiqing-feng): Done

Collaborator: I don't see this change propagated?

Author (jiqing-feng): Sorry, I forgot to push. It should be updated now, pls check. Thx.

(Thread marked as resolved by jiqing-feng.)
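The suggested change turns a silent fallback into a hard error. A minimal, self-contained sketch of the distinction, using a hypothetical check_backend helper rather than the actual code from the follow-up commit (whose exact message may differ):

    from typing import Optional


    def check_backend(requested: Optional[str], resolved: str) -> str:
        # Raise instead of warn: an explicitly requested backend that cannot
        # be honored is treated as a configuration error, not a log line.
        if requested is not None and resolved != requested:
            raise ValueError(f"The assigned backend is {requested}, but the real backend is {resolved}")
        return resolved


    # check_backend("mpi", "gloo") raises ValueError, whereas the
    # warning-based version above would continue running with "gloo".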
@@ -718,41 +721,45 @@ def _prepare_backend(
         elif is_torch_xla_available():
             backend = "xla"
             distributed_type = DistributedType.XLA
-        elif int(os.environ.get("LOCAL_RANK", -1)) != -1:
-            if not cpu:
-                if is_mlu_available():
-                    backend = "cncl"
-                    distributed_type = DistributedType.MULTI_MLU
-                elif torch.cuda.is_available():
-                    if backend is None:
-                        backend = "nccl"
-                    distributed_type = DistributedType.MULTI_GPU
-                elif is_npu_available():
-                    backend = "hccl"
-                    distributed_type = DistributedType.MULTI_NPU
-        if backend is None and (
+        elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
+            if is_mlu_available():
+                backend = "cncl"
+                distributed_type = DistributedType.MULTI_MLU
+            elif torch.cuda.is_available():
+                if backend is None:
+                    backend = "nccl"
+                distributed_type = DistributedType.MULTI_GPU
+            elif is_npu_available():
+                backend = "hccl"
+                distributed_type = DistributedType.MULTI_NPU
+
+        if distributed_type is None and (
             int(os.environ.get("LOCAL_RANK", -1)) != -1
             or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1
         ):
             if not cpu and is_xpu_available():
                 distributed_type = DistributedType.MULTI_XPU
             else:
                 distributed_type = DistributedType.MULTI_CPU
-            if is_ccl_available() and (
-                get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU
+
+            if (
+                (backend is None or backend == "ccl")
+                and is_ccl_available()
+                and (get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU)
             ):
                 if get_ccl_version() >= "1.12":
                     import oneccl_bindings_for_pytorch  # noqa: F401
                 else:
                     import torch_ccl  # noqa: F401
+
                 backend = "ccl"
-            elif torch.distributed.is_mpi_available():
+            elif (backend is None or backend == "mpi") and torch.distributed.is_mpi_available():
                 backend = "mpi"
             else:
                 backend = "gloo"
         if distributed_type is None:
             distributed_type = DistributedType.NO

         return backend, distributed_type

     def set_device(self):

(Two further review threads, on the new ccl and mpi guards, were marked as resolved by jiqing-feng.)
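Taken together, the new guards mean an explicitly requested backend is only overridden when it genuinely cannot be honored, and the multi-CPU/XPU branch is now gated on distributed_type rather than backend. The decision order can be summarized as a standalone function; this is an illustrative sketch only, where resolve_backend, local_rank, and the is_* flags are hypothetical stand-ins for accelerate's real environment lookups, availability checks, and DistributedType members:

    def resolve_backend(requested, cpu, local_rank, is_cuda, is_ccl, is_mpi):
        # Sketch of the fixed selection order in _prepare_backend.
        backend, distributed_type = requested, None
        if local_rank != -1 and not cpu and is_cuda:
            if backend is None:
                backend = "nccl"  # keep an explicit request intact
            distributed_type = "MULTI_GPU"
        # Gate on distributed_type (not backend), so an explicitly requested
        # backend no longer skips this branch.
        if distributed_type is None and local_rank != -1:
            distributed_type = "MULTI_CPU"
            if (backend is None or backend == "ccl") and is_ccl:
                backend = "ccl"  # only claim ccl when unset or requested
            elif (backend is None or backend == "mpi") and is_mpi:
                backend = "mpi"
            else:
                backend = "gloo"
        if distributed_type is None:
            distributed_type = "NO"
        return backend, distributed_type


    # An explicit but unavailable request now degrades visibly:
    # resolve_backend("mpi", cpu=True, local_rank=0, is_cuda=False, is_ccl=False, is_mpi=False)
    # returns ("gloo", "MULTI_CPU"), which the check in __init__ then reports.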