
[util] Add generic torch device class #6174

Merged · 19 commits into main · Apr 15, 2024

Conversation

@lstein (Collaborator) commented Apr 7, 2024

Summary

This PR cleans up the various calls to choose_torch_device(), torch_dtype(), choose_precision() and so on, removes redundant function calls, and creates a single class named TorchDeviceSelect that supersedes their functionality. In addition to creating a simplified API, this class generalizes the call to clear the VRAM cache so that the same method empties the VRAM cache for both CUDA and MPS devices. It also provides an API for invocation context-dependent retrieval of the GPU device, intended for use in GPU load balancing in the future.

Example usage:

from invokeai.backend.util.devices import TorchDevice

class Foo(BaseInvocation):
    def invoke(self, context: InvocationContext):
        torch_device = TorchDevice.choose_torch_device()
        torch_dtype = TorchDevice.choose_torch_dtype()

        # empty CUDA or MPS cache
        TorchDevice.empty_cache()

The methods that return strings instead of objects, e.g. "float32" rather than torch.float32, have been removed. The legacy choose_torch_device() and choose_precision() functions now issue a deprecation warning.
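As a rough illustration of the consolidated API and the deprecation shim described above, here is a stdlib-only sketch. The torch.cuda / torch.backends.mps probes are stubbed out as class attributes, so none of the bodies below are the real implementation; the real calls are noted in comments.

```python
import warnings


class TorchDevice:
    """Sketch of the unified helper. The real class probes torch.cuda and
    torch.backends.mps; those probes are stubbed out as class attributes here."""

    cuda_available = False  # stand-in for torch.cuda.is_available()
    mps_available = False   # stand-in for torch.backends.mps.is_available()

    @classmethod
    def choose_torch_device(cls) -> str:
        # Prefer CUDA, then MPS, then fall back to CPU.
        if cls.cuda_available:
            return "cuda"
        if cls.mps_available:
            return "mps"
        return "cpu"

    @classmethod
    def empty_cache(cls) -> None:
        # One method covers both backends; the real implementation calls
        # torch.cuda.empty_cache() or torch.mps.empty_cache() as appropriate.
        if cls.cuda_available:
            pass  # torch.cuda.empty_cache()
        if cls.mps_available:
            pass  # torch.mps.empty_cache()


def choose_torch_device() -> str:
    """Legacy function, retained but deprecated in favor of the class method."""
    warnings.warn(
        "choose_torch_device() is deprecated; use TorchDevice.choose_torch_device()",
        DeprecationWarning,
        stacklevel=2,
    )
    return TorchDevice.choose_torch_device()
```

Routing the legacy function through the class keeps old call sites working while pointing users at the new API.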

Related Issues / Discussions

QA Instructions

Merge Plan

Squash merge when approved.

Checklist

  • The PR has a short but descriptive title, suitable for a changelog
  • Tests added / updated (if applicable)
  • Documentation added / updated (if applicable)

@github-actions bot added labels python, invocations, backend, services, python-tests on Apr 7, 2024
@lstein lstein marked this pull request as draft April 7, 2024 21:01
@lstein lstein force-pushed the lstein/feat/device-abstraction branch from 7898699 to 577bf62 Compare April 7, 2024 21:06
@lstein lstein marked this pull request as ready for review April 7, 2024 21:15
Review threads (outdated, resolved) on: invokeai/backend/util/devices.py (×5), tests/backend/util/test_devices.py, invokeai/app/invocations/latent.py
@lstein lstein requested a review from RyanJDick April 10, 2024 22:53
@psychedelicious (Collaborator)

What's the difference between "auto" and "autocast"?

@lstein (Collaborator, Author) commented Apr 11, 2024

What's the difference between "auto" and "autocast"?

"auto" selects one of the floating point precision types. "autocast" used to activate the torch.autocast() context for certain generation operations, but this code has been removed, so I've gone ahead and removed references to this configuration option.

By the way, I noticed a TODO from Ryan in model_patcher.py saying that torch.autocast might provide a speed benefit in model patching. It could be hard-coded there if needed.

@psychedelicious (Collaborator)

"auto" selects one of the floating point precision types. "autocast" used to activate the torch.autocast() context for certain generation operations, but this code has been removed, so I've gone ahead and removed references to this configuration option.

Gotcha. We need to be careful with changing valid config settings. If somebody has "autocast" in their invokeai.yaml file, they will get a pydantic error on startup. It may be OK for this particular setting, but generally we'll want to write a config migration script, like config_default.py:migrate_v3_config_dict
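The migration this comment asks for could look roughly like the following (a hypothetical stdlib-only sketch; the function name is invented here, and the real migration lives in InvokeAI's config_default.py and handles the full config schema):

```python
from typing import Any, Dict


def migrate_precision_setting(config: Dict[str, Any]) -> Dict[str, Any]:
    """Map the removed "autocast" precision value to "auto" before the
    config dict is handed to pydantic for validation.

    Hypothetical sketch in the spirit of config_default.py:migrate_v3_config_dict.
    """
    migrated = dict(config)  # leave the caller's dict untouched
    if migrated.get("precision") == "autocast":
        migrated["precision"] = "auto"
    return migrated
```

Running this before validation means a user with precision: autocast in invokeai.yaml gets a silent upgrade instead of a startup error.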

@lstein (Collaborator, Author) commented Apr 11, 2024

Right. I'll put in a migration script.

@RyanJDick (Collaborator) left a comment

Good work on this PR. I love to see our test coverage going up! 😄

I've approved, but would still like some clarity on how you're thinking about the following:
1. As @psychedelicious mentioned, this will break configs that contain autocast. Are we thinking that there are probably very few people with that setting, so we'll just provide guidance to anyone who does hit this?
Edit: I just saw your comment about addressing this.
2. How confident are we that context.models.get_execution_device() is the right API? I can imagine how it would enable multi-GPU, but without the full context of how it's going to be used, I'm slightly nervous that we're going to end up making breaking changes to it.

@psychedelicious (Collaborator)

Indeed we should be careful when adding to the public API, because we are promising to support and maintain it.

Some more thoughts:

  • I'm not sure get_execution_device should be in the context.models namespace. Feels more like a utility to me.
  • There's no indication that using TorchDeviceSelect.get_execution_device() directly in a node is going to cause problems - because it doesn't, except in a separate branch that may be merged in the future, where it won't work for some nodes. In fact, some nodes do use TorchDeviceSelect directly. It's confusing and has footgun potential.
  • "TorchDeviceSelect" suggests to me that I can use this class to select a torch device. Maybe something more generic like "TorchHelper" is clearer.
  • The invocation_api module exports some objects that are removed in this PR:
    from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device
    For better or worse, we have committed these to the public API. Need to handle that.
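One way to honor that commitment is a thin compatibility layer that restores the legacy names as aliases over the new class. This is a hypothetical sketch (device objects are plain strings here rather than torch.device, and the real legacy signatures may differ):

```python
# Compatibility shim sketch: restore the legacy public names removed from
# invokeai.backend.util.devices so imports via invocation_api keep working.
# The constants were originally torch.device objects; strings stand in here.

CPU_DEVICE = "cpu"    # was torch.device("cpu")
CUDA_DEVICE = "cuda"  # was torch.device("cuda")
MPS_DEVICE = "mps"    # was torch.device("mps")


def choose_precision(device: str) -> str:
    """Legacy precision picker: half precision on GPU backends, full on CPU.

    Hypothetical behavior; the real function's rules may be more nuanced
    (e.g. per-GPU capability checks).
    """
    return "float16" if device in (CUDA_DEVICE, MPS_DEVICE) else "float32"
```

Keeping these as aliases means downstream nodes that imported the old names keep working while new code moves to TorchDevice.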

@lstein lstein marked this pull request as draft April 12, 2024 05:36
@lstein (Collaborator, Author) commented Apr 12, 2024

I'm tired of working on this and will get back to it at some point soon. Converting to draft for now.

1. Remove TorchDeviceSelect.get_execution_device(), as well as calls to
   context.models.get_execution_device().
2. Rename TorchDeviceSelect to TorchDevice
3. Added back the legacy public API defined in `invocation_api`, including
   choose_precision().
4. Added a config file migration script to accommodate removal of precision=autocast.
@lstein (Collaborator, Author) commented Apr 14, 2024

Indeed we should be careful when adding to the public API, because we are promising to support and maintain it.

Some more thoughts:

  • I'm not sure get_execution_device should be in the context.models namespace. Feels more like a utility to me.

  • There's no indication that using TorchDeviceSelect.get_execution_device() directly in a node is going to cause problems - because it doesn't, except in a separate branch that may be merged in the future, where it won't work for some nodes. In fact, some nodes do use TorchDeviceSelect directly. It's confusing and has footgun potential.

  • "TorchDeviceSelect" suggests to me that I can use this class to select a torch device. Maybe something more generic like "TorchHelper" is clearer.

  • The invocation_api module exports some objects that are removed in this PR:

    from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device

    For better or worse, we have committed these to the public API. Need to handle that.

These recommendations have now been implemented.

@lstein lstein marked this pull request as ready for review April 14, 2024 21:27
@lstein lstein changed the title [util] Add object-oriented torch device class [util] Add generic torch device class Apr 14, 2024
@lstein lstein force-pushed the lstein/feat/device-abstraction branch from 48ce7b4 to 7e177c1 Compare April 14, 2024 22:54
@psychedelicious (Collaborator) left a comment

Thanks for addressing the feedback, sorry if it was a hassle. A couple of minor changes and comments.

Review threads (outdated, resolved) on: invokeai/backend/model_manager/merge.py, invokeai/backend/util/devices.py (×2)
@psychedelicious psychedelicious self-requested a review April 15, 2024 02:21
@lstein lstein enabled auto-merge (squash) April 15, 2024 13:03
@lstein lstein merged commit e93f4d6 into main Apr 15, 2024
14 checks passed
@lstein lstein deleted the lstein/feat/device-abstraction branch April 15, 2024 13:12