Apple silicon mps support #47
Conversation
There's a bit of duplication in the device selection logic, since I didn't originally intend to PR these changes. I'm not sure if you want to clean that up or not. 😅
scripts/img2img.py (outdated diff)

```diff
 precision_scope = autocast if opt.precision == "autocast" else nullcontext
+if device.type == 'mps':
+    precision_scope = nullcontext # have to use f32 on mps
 with torch.no_grad():
-    with precision_scope("cuda"):
+    with precision_scope(device.type):
```
This is 90% of the way to being able to run inference on the CPU; however, you will get a type mismatch error, expecting a BFloat16 but getting a Float. To fix this, it should be more like `if device.type in ['mps', 'cpu']:`, because CPUs don't have BFloat16 operations either. Alternatively, people can run the model with `--precision full` when things don't work.

I am not sure if this is a bug in PyTorch or not, because my impression is that `autocast()` should not be offering BFloat16 optimizations to anything that is not a recent Nvidia GPU. Intel apparently has BF16 operations in recent Xeon products. Maybe this is a better way?
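For illustration, here is a minimal sketch of the precision-scope selection being discussed. The `opt.precision` option, `--precision full`, and the autocast/nullcontext pattern come from the scripts in this repo; the helper function name and the device probe are assumptions added for the example.

```python
from contextlib import nullcontext

import torch
from torch import autocast


def choose_precision_scope(precision: str, device: torch.device):
    """Pick the context manager to wrap sampling in (helper name is hypothetical)."""
    # autocast's reduced-precision path is only wanted on CUDA here; on MPS and
    # CPU it can pick BFloat16 kernels and fail with "expected scalar type
    # BFloat16 but found Float", so fall back to full precision (nullcontext),
    # which is what --precision full does.
    if precision == "autocast" and device.type not in ("mps", "cpu"):
        return autocast
    return nullcontext


# Probe for the best available device; the MPS check is guarded because not
# every PyTorch build exposes torch.backends.mps (see the ddim.py comment below).
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

precision_scope = choose_precision_scope("autocast", device)
with torch.no_grad():
    with precision_scope(device.type):
        pass  # model sampling would go here
```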
Thank you for sharing the solution 🙇 I was running `txt2img.py` and it kept failing with `RuntimeError: expected scalar type BFloat16 but found Float`, but adding `--precision full` when running the script worked 🎉
Wow, this perfectly solves the same problem I had.
ldm/models/diffusion/ddim.py (outdated diff)

```python
elif(torch.backends.mps.is_available()):
    self.device_available = "mps"
```
`torch.backends.mps` is not available in all installations of PyTorch. This needs to be either checked or wrapped in a try/except to avoid a crash.
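A minimal sketch of the guarded check being suggested; the attribute names are PyTorch's, while the helper function name is hypothetical.

```python
import torch


def mps_is_available() -> bool:
    """Report whether this PyTorch build exposes a usable MPS backend."""
    # Older or non-macOS builds of PyTorch do not have torch.backends.mps at
    # all, so probe for the attribute before calling is_available() to avoid
    # an AttributeError.
    backend = getattr(torch.backends, "mps", None)
    return backend is not None and backend.is_available()
    # Equivalent try/except form:
    #   try:
    #       return torch.backends.mps.is_available()
    #   except AttributeError:
    #       return False
```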
environment-mac.yaml (outdated diff)

```yaml
  - numpy=1.19.2
  - pip:
    - albumentations==0.4.3
    - opencv-python==4.1.2.30
```
FYI, this opencv-python version (from November 2019) is no longer offered on PyPI:

```
ERROR: Could not find a version that satisfies the requirement opencv-python==4.1.2.30 (from versions: 3.4.0.14, 3.4.10.37, 3.4.11.39, 3.4.11.41, 3.4.11.43, 3.4.11.45, 3.4.13.47, 3.4.15.55, 3.4.16.57, 3.4.16.59, 3.4.17.61, 3.4.17.63, 3.4.18.65, 4.3.0.38, 4.4.0.40, 4.4.0.42, 4.4.0.44, 4.4.0.46, 4.5.1.48, 4.5.3.56, 4.5.4.58, 4.5.4.60, 4.5.5.62, 4.5.5.64, 4.6.0.66)
```
environment-mac.yaml (outdated diff)

```yaml
  - pytorch
  - defaults
dependencies:
  - python=3.8.5
```
Probably should be `python=3.8.11`.
environment-mac.yaml (outdated diff)

```yaml
  - defaults
dependencies:
  - python=3.8.5
  - pip=20.3
```
Current pip is more like 22.2.2
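Pulling the suggestions in this thread together, here is a sketch of how the pinned versions in environment-mac.yaml could look. Only the versions mentioned in the comments above are reflected (plus the newest opencv-python that the pip error lists as available); the rest of the file is assumed unchanged.

```yaml
# Sketch only: version bumps taken from the review comments and the pip error
# output above; remaining entries of environment-mac.yaml are assumed unchanged.
name: ldm
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.8.11          # suggested bump from 3.8.5
  - pip=22.2.2             # suggested bump from 20.3
  - numpy=1.19.2
  - pip:
    - albumentations==0.4.3
    - opencv-python==4.6.0.66   # newest version still published, per the pip error
```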
I closed this by accident (I didn't know that bringing my branch up to date would close it). I'm going to try to implement what @illeatmyhat suggested and then resubmit it.
It says you merged zero commits into CompVis:main, hence this one got closed. You might have wanted to pull from main into magnusviri:apple-silicon-mps-support instead?
I don't have permission to merge into their repo, and I don't see my changes in their repo, but I have their changes in mine. The wording always confuses me, and I've never done what I did before, so I didn't realize it would close the pull request. IDK. Anyway, they haven't done a merge since they released this 3 days ago. I've been pretty busy playing with this, so I haven't done much coding to improve my code yet. I'm kind of waiting to see how the developers interact with the community before I fix the merge request (I've done a lot of work for other projects only to have it flat-out rejected, so I'm not too eager to work on it unless I know it will be merged). If anything, I might do a pull request for the lstein fork because he has responded.
Updated ipynb file from py
feat(): add copy to img2img
I don't have access to the models, so I haven't tested this. But I have tested the `torch.device` code and it does work.
To get this to work, the only thing that needs to be done differently is that the conda environment needs to use `environment-mac.yaml` instead of `environment.yaml`, because the `cudatoolkit` dependency in `environment.yaml` generates an error. I also believe (but haven't verified) that torch and torchvision need to be updated, so I updated them.