diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 95ef5af34e0c..2b093a85002d 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -93,6 +93,7 @@ def type(self) -> type: ... # TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 float8_e3m4: type[np.generic] | None = None float8_e4m3: type[np.generic] | None = None +float8_e8m0fnu: type[np.generic] | None = None float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz @@ -101,6 +102,7 @@ def type(self) -> type: ... _float8_e3m4_dtype: np.dtype | None = None _float8_e4m3_dtype: np.dtype | None = None +_float8_e8m0fnu_dtype: np.dtype | None = None _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -155,6 +157,12 @@ def supports_inf(dtype: DTypeLike) -> bool: _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] _custom_float_dtypes.insert(0, _float8_e3m4_dtype) _float8_dtypes.insert(0, _float8_e3m4_dtype) +if hasattr(ml_dtypes, "float8_e8m0fnu"): + float8_e8m0fnu = ml_dtypes.float8_e8m0fnu + _float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu) + _custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype) + _float8_dtypes.insert(0, _float8_e8m0fnu_dtype) # 2-bit integer support int2: type[np.generic] | None = None diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index b71b377d8999..a086562526dd 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -74,6 +74,7 @@ enum DType: byte { f8_e4m3fnuz = 19, f8_e5m2 = 20, f8_e5m2fnuz = 21, + f8_e8m0fnu = 25, } table AbstractValue { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 0d9ce961b556..7707670f1f82 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -363,7 +363,8 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 if dtypes._float8_e4m3_dtype is not None: _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 - +if dtypes._float8_e8m0fnu_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() } diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index 70d298020961..69092cd7edcd 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -61,6 +61,7 @@ class DType(object): f8_e5m2 = 20 f8_e5m2fnuz = 21 f0 = 22 + f8_e8m0fnu = 25 class ShardingKind(object): diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 97f91555cb05..639a81dac1b1 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -192,6 +192,8 @@ def _is_ir_values(x: IrValues) -> bool: _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get if dtypes.float8_e4m3 is not None: _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get +if dtypes.float8_e8m0fnu is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index cae7a762e313..105d9be6f685 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1021,6 +1021,8 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] if dtypes.float8_e4m3 is not None: fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] + if dtypes.float8_e8m0fnu is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " @@ -3764,6 +3766,8 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes += (dtypes.float8_e3m4,) if dtypes.float8_e4m3 is not None: fp8_dtypes += (dtypes.float8_e4m3,) + if dtypes.float8_e8m0fnu is not None: + fp8_dtypes += (dtypes.float8_e8m0fnu,) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes del preferred_element_type # Implied by the output aval lhs_aval, rhs_aval = ctx.avals_in diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 9aa131420b5f..d64fd5ba9efc 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -225,6 +225,8 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) if dtypes.float8_e4m3 is not None: float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) +if dtypes.float8_e8m0fnu is not None: + float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 6bbcdd08471f..93a6c29c2669 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -97,6 +97,9 @@ def default_tolerance(): if _dtypes.float8_e4m3 is not None: _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 +if _dtypes.float8_e8m0fnu is not None: + _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 + default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -119,6 +122,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): custom_float_dtypes.insert(0, _dtypes.float8_e4m3) if _dtypes.float8_e3m4 is not None: custom_float_dtypes.insert(0, _dtypes.float8_e3m4) + if _dtypes.float8_e8m0fnu is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu) def maybe_upcast(x): if x.dtype in custom_float_dtypes: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 4e2a4f4fdeee..e3970a9c78f8 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1385,6 +1385,8 @@ def custom_floats(self): float_dtypes += [_dtypes.float8_e3m4] if _dtypes.float8_e4m3 is not None: float_dtypes += [_dtypes.float8_e4m3] + if _dtypes.float8_e8m0fnu is not None: + float_dtypes += [_dtypes.float8_e8m0fnu] return self.supported(float_dtypes) @_cached_property diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index c447b0844892..6873e5b7c14b 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -281,6 +281,7 @@ from jax._src.numpy.lax_numpy import ( float8_e3m4 as float8_e3m4, float8_e4m3 as float8_e4m3, + float8_e8m0fnu as float8_e8m0fnu, ) except ImportError: pass diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 6c7e9e3ab712..f0dacdf55a4c 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -68,6 +68,8 @@ fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] if dtypes.float8_e4m3 is not None: fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] +if dtypes.float8_e8m0fnu is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes