[util] Add generic torch device class #6174
Conversation
Force-pushed from 7898699 to 577bf62.
What's the difference between "auto" and "autocast"?
invokeai/backend/model_manager/load/model_cache/model_cache_base.py
"auto" selects one of the floating point precision types. "autocast" used to activate the By the way, I notice a TODO from ryan in |
Gotcha. We need to be careful with changing valid config settings. If somebody has "autocast" in their config …
Right. I'll put in a migration script.
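Such a migration could be sketched roughly as follows. This is a minimal illustration, not the PR's actual migration script; the `precision` key name and the dict-shaped config are assumptions for the sketch.

```python
# Hypothetical sketch of a config migration that maps the retired
# "autocast" precision setting to "auto". Key names are assumptions.
def migrate_precision(config: dict) -> dict:
    """Return a copy of the config with precision=autocast rewritten to auto."""
    migrated = dict(config)  # don't mutate the caller's config
    if migrated.get("precision") == "autocast":
        migrated["precision"] = "auto"
    return migrated
```

Running this over an existing config leaves every other setting untouched, so it is safe to apply unconditionally on startup.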
Good work on this PR. I love to see our test coverage going up! 😄
I've approved, but would still like some clarity on how you're thinking about the following:
1. As @psychedelicious mentioned, this will break configs that contain autocast. Are we thinking that there are probably very few people with that setting? So we'll just provide guidance to anyone who does hit this?
Edit: I just saw your comment about addressing this.
2. How confident are we that `context.models.get_execution_device()` is the right API? I can imagine how it would enable multi-GPU, but without the full context of how it's going to be used I'm slightly nervous that we're going to end up making breaking changes to it.
Indeed we should be careful when adding to the public API, because we are promising to support and maintain it. Some more thoughts:
I'm tired of working on this and will get back to it at some point soon. Converting to draft for now.
1. Removed `TorchDeviceSelect.get_execution_device()`, as well as calls to `context.models.get_execution_device()`.
2. Renamed `TorchDeviceSelect` to `TorchDevice`.
3. Added back the legacy public API defined in `invocation_api`, including `choose_precision()`.
4. Added a config file migration script to accommodate the removal of precision=autocast.
These recommendations have now been implemented.
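Keeping a legacy function alive while steering callers toward its replacement is commonly done with a thin deprecation shim. The sketch below is illustrative only: the function name is borrowed from the PR, but the body and the CUDA/CPU precision rule are simplified assumptions, not the actual InvokeAI implementation.

```python
import warnings

# Hypothetical shim: the legacy entry point keeps working but emits a
# DeprecationWarning pointing at the new TorchDevice class.
def choose_precision(device: str) -> str:
    warnings.warn(
        "choose_precision() is deprecated; use the TorchDevice class instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Delegate to the replacement logic (greatly simplified here).
    return "float16" if device == "cuda" else "float32"
```

`stacklevel=2` makes the warning point at the caller's line rather than the shim itself, which is what you want when nudging downstream code to migrate.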
Force-pushed from 48ce7b4 to 7e177c1.
Thanks for addressing the feedback; sorry if it was a hassle. A couple of minor changes and comments.
Summary
This PR cleans up the various calls to `choose_torch_device()`, `torch_dtype()`, `choose_precision()` and so on, removes redundant function calls, and creates a single class named `TorchDeviceSelect` that supersedes their functionality. In addition to creating a simplified API, this class generalizes the call to clear the VRAM cache so that the same method empties the VRAM cache for both CUDA and MPS devices. It also provides an API for invocation context-dependent retrieval of the GPU device, intended for use in GPU load balancing in the future.

Example usage:
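The example itself did not survive the page scrape. As a stand-in, here is a minimal sketch of what a unified device-selection class of this shape might look like. The availability probes are stubbed with booleans so the sketch runs without torch installed; the method names approximate, rather than quote, the PR's actual API.

```python
# Hypothetical sketch of a unified torch-device helper. In the real class
# the probes would call torch.cuda.is_available() / torch.backends.mps.
class TorchDeviceSelect:
    def __init__(self, cuda_available: bool = False, mps_available: bool = False):
        self.cuda_available = cuda_available
        self.mps_available = mps_available

    def choose_torch_device(self) -> str:
        """Prefer CUDA, then MPS, then fall back to CPU."""
        if self.cuda_available:
            return "cuda"
        if self.mps_available:
            return "mps"
        return "cpu"

    def choose_torch_dtype(self) -> str:
        """Half precision only on an accelerator; full precision on CPU."""
        return "float16" if self.choose_torch_device() != "cpu" else "float32"

    def empty_cache(self) -> None:
        """One call that clears the VRAM cache on CUDA and MPS alike."""
        # On real hardware: torch.cuda.empty_cache() or torch.mps.empty_cache().
        pass
```

The point of the design is that callers no longer branch on backend: one object answers "which device", "which dtype", and "free the cache" for every platform.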
The methods that return strings instead of objects, e.g. "float32" rather than `torch.float32`, have been removed. The legacy `choose_torch_device()` and `choose_precision()` functions now issue a deprecation warning.

Related Issues / Discussions
QA Instructions
Merge Plan
Squash merge when approved.
Checklist