-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add float8_e8m0fnu type support #25116
base: main
Are you sure you want to change the base?
Conversation
Thanks! Note that Given that, you'll have to account for the ml_dtypes version in defining this. See #23585 for an example of how to do this conditional dtype definition. |
gently ping @jakevdp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks!
Could you please rebase on the main branch to address test failures? Thanks! |
ddbf557
to
7577ac9
Compare
CI failures look real here – is it possible that there's a minimum required jaxlib version for support of this new dtype? |
Yes. The failure is caused by XLA doesn't recognize e8m0fnu yet. Corresponding XLA PR in preparation. Do you suggest to hold on merge this PR until XLA is ready to remove the failed test for now? |
I'm not sure – @hawkinsp do you have thoughts here? Should we merge a dtype that XLA doesn't yet support? |
openxla/xla#19096 has been merged. |
This PR adds E8M0fnu type support.
E8M0fnu is a OpenCompute MX scale format, which has the following properties:
Unsigned format
8 exponent bits
Exponent range from -127 to 127
No zero and infinity
Single NaN value (0xFF).
@jakevdp
Smoke test
Seeing error