diff --git a/tsl/platform/BUILD b/tsl/platform/BUILD index eb208e6ef..9f3b13524 100644 --- a/tsl/platform/BUILD +++ b/tsl/platform/BUILD @@ -985,6 +985,7 @@ cc_library( deps = [ "@ml_dtypes//:float8", "@ml_dtypes//:intn", + "@ml_dtypes//:mxfloat", ], ) diff --git a/tsl/platform/ml_dtypes.h b/tsl/platform/ml_dtypes.h index a6a1b56af..a03fa0244 100644 --- a/tsl/platform/ml_dtypes.h +++ b/tsl/platform/ml_dtypes.h @@ -18,8 +18,10 @@ limitations under the License. #include "ml_dtypes/include/float8.h" // from @ml_dtypes #include "ml_dtypes/include/intn.h" // from @ml_dtypes +#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes namespace tsl { +using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn; using float8_e3m4 = ::ml_dtypes::float8_e3m4; using float8_e4m3 = ::ml_dtypes::float8_e4m3; using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; @@ -27,6 +29,7 @@ using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; using float8_e5m2 = ::ml_dtypes::float8_e5m2; using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz; +using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu; using int1 = ::ml_dtypes::int1; using uint1 = ::ml_dtypes::uint1;