Add float8_e8m0fnu type support #25116

Open
wants to merge 1 commit into
base: main

wenscarl
Contributor

@wenscarl wenscarl commented Nov 26, 2024

This PR adds E8M0fnu type support.
E8M0fnu is an OpenCompute MX scale format with the following properties:

Unsigned format
8 exponent bits
Exponent range from -127 to 127
No zero or infinity
Single NaN value (0xFF).
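
To make the properties above concrete, here is a minimal pure-Python sketch of how a single E8M0fnu byte maps to a value. It is not the ml_dtypes or XLA implementation. The bias of 127 is an assumption inferred from the listed exponent range (codes 0x00 to 0xFE covering -127 to 127), and the helper names `decode_e8m0fnu` and `encode_e8m0fnu` are hypothetical.

```python
import math

E8M0_BIAS = 127  # assumed: inferred from 8 exponent bits spanning -127..127
E8M0_NAN = 0xFF  # the single NaN encoding listed above


def decode_e8m0fnu(byte: int) -> float:
    """Decode one E8M0fnu byte: no sign bit, no mantissa bits, just 2**exponent."""
    if not 0 <= byte <= 0xFF:
        raise ValueError(f"not a byte: {byte!r}")
    if byte == E8M0_NAN:
        return math.nan
    return math.ldexp(1.0, byte - E8M0_BIAS)


def encode_e8m0fnu(value: float) -> int:
    """Encode an exact power of two (or NaN); this sketch rejects everything else.

    A real conversion would round; that behavior is deliberately left out here.
    """
    if math.isnan(value):
        return E8M0_NAN
    mantissa, exponent = math.frexp(value)  # value == mantissa * 2**exponent
    unbiased = exponent - 1                 # frexp yields mantissa 0.5 for powers of two
    if value <= 0 or mantissa != 0.5 or not -127 <= unbiased <= 127:
        raise ValueError(f"not representable in this sketch: {value!r}")
    return unbiased + E8M0_BIAS
```

Because the format has no sign bit and no zero, the encoder in this sketch raises on zero, negative numbers, and infinity instead of guessing a rounding rule.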

@jakevdp

Smoke test

import jax
import jax.numpy as jnp

def foo(a):
   return jax.lax.bitcast_convert_type(a, new_dtype=jnp.float8_e8m0fnu)
a = jnp.ones((2,2),dtype=jnp.float8_e4m3fn)
foo_jit = jax.jit(foo)

# StableHLO
print(foo_jit.lower(a).as_text("stablehlo"))
# HLO
print(foo_jit.lower(a).as_text("hlo"))
# Compiled (optimized) HLO
print(foo_jit.lower(a).compile().as_text())

Running this produces the following error:

jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Failed to serialize StableHLO;

Detailed error from MLIR: <unknown>:0: error: failed to legalize operation 'vhlo.func_v1' that was explicitly marked illegal
<unknown>:0: note: see current operation:
"vhlo.func_v1"() <{arg_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{}>]>, function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.tensor_v1<256x4096x1024x!vhlo.f8E4M3FN_v1>) -> !vhlo.tensor_v1<256x4096x1024x!vhlo.f8E8M0FNU_v1>>>, res_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"jax.result_info"> = #vhlo.string_v1<"">}>]>, sym_name = #vhlo.string_v1<"main">, sym_visibility = #vhlo.string_v1<"public">}> ({
^bb0(%arg0: !vhlo.tensor_v1<256x4096x1024x!vhlo.f8E4M3FN_v1>):
  %0 = "vhlo.bitcast_convert_v1"(%arg0) : (!vhlo.tensor_v1<256x4096x1024x!vhlo.f8E4M3FN_v1>) -> !vhlo.tensor_v1<256x4096x1024x!vhlo.f8E8M0FNU_v1>
  "vhlo.return_v1"(%0) : (!vhlo.tensor_v1<256x4096x1024x!vhlo.f8E8M0FNU_v1>) -> ()
}) : () -> ()

@wenscarl wenscarl marked this pull request as draft November 26, 2024 17:01
@jakevdp
Collaborator

jakevdp commented Dec 2, 2024

Thanks! Note that float8_e8m0fnu was added to ml_dtypes in version 0.5.0 (see the Change log), and JAX must still maintain compatibility with v0.4.0 because the latest release of tensorflow pins ml_dtypes<0.5.0.

Given that, you'll have to account for the ml_dtypes version in defining this. See #23585 for an example of how to do this conditional dtype definition.
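
As a rough illustration of conditional dtype definition, the sketch below looks a type up by name and falls back to `None` when the installed package is too old to provide it. This is not the code from #23585, which is not reproduced in this thread; the helper `optional_attr` and the list name `supported_float8_dtypes` are hypothetical.

```python
import importlib
from typing import Any, Optional


def optional_attr(module_name: str, attr: str) -> Optional[Any]:
    """Return module.attr, or None if the module or the attribute is missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, attr, None)


# None when ml_dtypes is absent or predates the release that added the type.
float8_e8m0fnu = optional_attr("ml_dtypes", "float8_e8m0fnu")

# Hypothetical registry: only list the dtype when it actually exists.
supported_float8_dtypes = []
if float8_e8m0fnu is not None:
    supported_float8_dtypes.append(float8_e8m0fnu)
```

Checking for the attribute rather than comparing version strings keeps the guard working with any ml_dtypes release, at the cost of not saying which version is required.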

@jakevdp jakevdp self-assigned this Dec 2, 2024
@wenscarl wenscarl marked this pull request as ready for review December 11, 2024 21:31
@wenscarl
Contributor Author

Gentle ping @jakevdp

Collaborator

@jakevdp jakevdp left a comment

Looks good, thanks!

@google-ml-butler google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels Dec 18, 2024
@jakevdp
Collaborator

jakevdp commented Dec 18, 2024

Could you please rebase on the main branch to address test failures? Thanks!

@jakevdp
Collaborator

jakevdp commented Dec 18, 2024

CI failures look real here – is it possible that there's a minimum required jaxlib version for support of this new dtype?
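
One way to express a minimum-version requirement like the one asked about here is a small version gate that tests can consult before exercising the new dtype. The sketch below is only an assumption about how such a check could look: `version_tuple`, `meets_minimum`, and the placeholder `HYPOTHETICAL_MIN_JAXLIB` are invented for illustration, and the real minimum jaxlib version is not stated in this thread.

```python
import re

# Placeholder only: the actual minimum jaxlib version is not given in this thread.
HYPOTHETICAL_MIN_JAXLIB = "9.9.9"


def version_tuple(version: str) -> tuple:
    """Parse the leading numeric components of a version string into a 3-tuple.

    Suffixes such as ".dev20241218" are ignored; missing components become 0.
    """
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    parts = parts[:3]
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if the installed version is at least the minimum version."""
    return version_tuple(installed) >= version_tuple(minimum)
```

A test could then skip itself when `meets_minimum(installed_version, HYPOTHETICAL_MIN_JAXLIB)` is false, where `installed_version` is whatever version string the installed package reports.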

@wenscarl
Contributor Author

> CI failures look real here – is it possible that there's a minimum required jaxlib version for support of this new dtype?

Yes. The failure occurs because XLA doesn't recognize e8m0fnu yet. The corresponding XLA PR is in preparation. Do you suggest holding off on merging this PR until XLA is ready, or removing the failing test for now?

@jakevdp
Collaborator

jakevdp commented Dec 19, 2024

> Do you suggest holding off on merging this PR until XLA is ready, or removing the failing test for now?

I'm not sure – @hawkinsp do you have thoughts here? Should we merge a dtype that XLA doesn't yet support?

@wenscarl
Contributor Author

> > Do you suggest holding off on merging this PR until XLA is ready, or removing the failing test for now?
>
> I'm not sure – @hawkinsp do you have thoughts here? Should we merge a dtype that XLA doesn't yet support?

openxla/xla#19096 has been merged.
