Update modeling.py by adding try-catch section to skip the unavailable devices #2681

Merged
2 commits merged into huggingface:main on May 6, 2024
Conversation

@MeVeryHandsome (Contributor) commented on Apr 17, 2024

What does this PR do?

In the current implementation of the library, when deploying tensor operations across multiple GPU devices, the program attempts to initialize a tensor on each device and query its maximum available memory:

def get_max_memory(max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None):
        ...
                # Touch every CUDA device once, then record its free memory.
                # torch.cuda.mem_get_info(i) returns (free_bytes, total_bytes) for device i.
                for i in range(torch.cuda.device_count()):
                    _ = torch.tensor([0], device=i)
                max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())}

However, this approach raises an error and terminates the program if it encounters a GPU that is unavailable (e.g. its memory is fully occupied or the device is damaged).

For example, this error will occur when one of the devices is fully occupied:

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.

To address this issue, I propose a modification to the logic for determining the maximum available memory on GPUs. Specifically, I have introduced a try-except block around the tensor deployment operation.

            max_memory = {}
            for i in range(torch.cuda.device_count()):
                try:
                    _ = torch.tensor([0], device=i)
                    max_memory[i] = torch.cuda.mem_get_info(i)[0]
                except Exception:
                    # Skip devices that cannot be initialized and keep probing the rest.
                    logger.warning(f"Device {i} seems unavailable, proceeding to check subsequent devices.")
                    continue

If an error occurs because a GPU is unavailable, the code now catches the exception and moves on to the next device. This change ensures that the program gracefully skips unavailable GPUs and only uses the ones that are operational.
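
For context, here is a minimal usage sketch, not part of this PR, of how the resulting mapping could be consumed downstream; the `model` object and the use of `infer_auto_device_map` are illustrative assumptions:

    # Usage sketch under the proposed probing logic (assumes `model` is an existing nn.Module):
    # unusable GPUs would simply be absent from max_memory, so device-map inference
    # only places weights on devices that passed the probe.
    from accelerate import infer_auto_device_map
    from accelerate.utils import get_max_memory

    max_memory = get_max_memory()  # e.g. {0: free_bytes, 2: free_bytes, "cpu": ...}
    device_map = infer_auto_device_map(model, max_memory=max_memory)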

This adjustment is intended to improve the library's usability in multi-device setups. I am open to any suggestions or further improvements from the community.

Thank you for reading this :)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

I apologize, as I am uncertain about whom to tag. Perhaps @muellerzr or @pacman100 can review this. Thanks!

@MeVeryHandsome (Contributor, Author) commented on Apr 22, 2024

By the way, if the unavailable GPU is the first of the visible devices, an error will still occur in this section even when the above method is used to obtain the maximum available VRAM:

def set_module_tensor_to_device(
    module: nn.Module,
    tensor_name: str,
    device: Union[int, str, torch.device],
    value: Optional[torch.Tensor] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    fp16_statistics: Optional[torch.HalfTensor] = None,
    tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
):
    ...
        # No device argument: by default this touches the first visible CUDA device (see below).
        torch.cuda.empty_cache()

This is because torch.cuda.empty_cache() takes no device argument; it runs against the current CUDA device, which defaults to the first visible device (index 0) when no other device has been selected.

However, since this is PyTorch behavior, I personally believe it may not be necessary to address it within this project. If needed, one could wrap the call in a try-except block to handle this situation:

        try:
            torch.cuda.empty_cache()
        except Exception:
            logger.warning("torch.cuda.empty_cache() failed, please check that the first device is available.")
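
To make the default-device point concrete, here is a small illustrative snippet (a sketch only, assuming a machine with at least two usable GPUs) showing which device parameterless CUDA calls act on:

    import torch

    # The "current" CUDA device defaults to index 0; parameterless calls such as
    # torch.cuda.empty_cache() are issued with that device as the current one.
    print(torch.cuda.current_device())  # -> 0 by default
    torch.cuda.set_device(1)            # hypothetical: make device 1 the current device
    torch.cuda.empty_cache()            # now runs with device 1 as the current device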

@MeVeryHandsome changed the title from "Update modeling.py to ignore the unavailable devices" to "Update modeling.py to skip the unavailable devices" on Apr 23, 2024
@MeVeryHandsome changed the title from "Update modeling.py to skip the unavailable devices" to "Update modeling.py by adding try-catch section to skip the unavailable devices" on Apr 23, 2024

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr requested a review from @SunMarc on April 24, 2024
@SunMarc (Member) left a comment:

Thanks @MeVeryHandsome for the PR! I think it is best to use info instead of warning. As for your question about empty_cache, we use it a lot in other parts of the library and it wouldn't make sense to try/except each of these calls. This is probably something that PyTorch should try to fix!
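
In code, the requested change would look roughly like this (a sketch of the review suggestion, not the exact merged diff):

                except Exception:
                    # Log at info level instead of warning, per review feedback.
                    logger.info(f"Device {i} seems unavailable, proceeding to check subsequent devices.")
                    continue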

4 review comments on src/accelerate/utils/modeling.py (outdated, resolved)
Update src/accelerate/utils/modeling.py (applied review suggestions)
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
@SunMarc merged commit 11a3632 into huggingface:main on May 6, 2024
23 checks passed