Make tests succeed more on MPS #1463
Conversation
cc @BenjaminBossan :)
Thanks a lot @akx! We should be able to merge after fixing the styling checks! 🙏 Can you format your changes?
@younesbelkada Done... Can you maybe take a peek at #1467 so I don't need to manually remember to run
Thanks for making these tests device-agnostic and addressing follow-up enhancements to our testing suite! 🤩
```python
class PeftAutoModelTester(unittest.TestCase):
    dtype = torch.float16 if infer_device() == "mps" else torch.bfloat16
```
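The pattern behind this dtype switch can be sketched with plain `unittest` skips (a minimal illustration only; `device_supports_bf16`, `ExampleTest`, and the hard-coded device are hypothetical stand-ins, not PEFT code):

```python
import unittest


def device_supports_bf16(device: str) -> bool:
    # Assumption for illustration: the MPS backend lacks bfloat16 support,
    # while the CPU and CUDA configurations exercised in CI are treated
    # as supporting it.
    return device != "mps"


class ExampleTest(unittest.TestCase):
    device = "mps"  # stand-in for infer_device()

    @unittest.skipIf(
        not device_supports_bf16(device),
        "bfloat16 is not supported on this device",
    )
    def test_bf16_path(self):
        # Would exercise a bfloat16 code path on supporting devices.
        self.assertTrue(True)
```

With this shape, an inapplicable test is reported as skipped rather than quietly passing, which is the behavior the PR description asks for.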
Would it be possible to check `torch.cuda.is_bf16_supported()` instead, or does that not work on Mac?
I wouldn't trust anything that says `cuda` on the tin to work with Macs 😁
Maybe that could be held off for later, anyway?
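A device-aware capability check along the lines this thread discusses might look like the following sketch (an assumption-laden illustration, not PEFT code; `torch.cuda.is_bf16_supported()` only answers for CUDA devices, which is the objection raised above):

```python
import torch


def bf16_supported() -> bool:
    # Sketch: route the question to the active backend.
    if torch.cuda.is_available():
        # Only meaningful for CUDA builds.
        return torch.cuda.is_bf16_supported()
    if torch.backends.mps.is_available():
        # The MPS backend currently has no bfloat16 support.
        return False
    # Assumption for illustration: CPU can at least emulate bfloat16.
    return True
```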
Thanks @akx for helping out all the Mac users out there :)
Some tests would not succeed on MPS devices because `bfloat16` is not supported; this change makes them get skipped instead. On a similar note, tests that are not applicable (e.g. generated via a `parametrized` matrix and that should have been skipped) no longer quietly pass, but are instead marked `skip`. It also fixes the return type of `infer_device` to always be a string: previously, it would return a string, or in the case of `mps`, a `torch.device()`. Sibling of #1448, since I got frustrated trying to run tests on my Apple Silicon device and figuring out which of them are actual failures.