diff --git a/third_party/tsl/third_party/py/ml_dtypes/e8m0.patch b/third_party/tsl/third_party/py/ml_dtypes/e8m0.patch new file mode 100644 index 0000000000000..add96113db620 --- /dev/null +++ b/third_party/tsl/third_party/py/ml_dtypes/e8m0.patch @@ -0,0 +1,99 @@ +diff --git a/include/float8.h b/include/float8.h +index 51d91fd..c53eb45 100644 +--- a/include/float8.h ++++ b/include/float8.h +@@ -1021,11 +1021,11 @@ struct numeric_limits_float8_e8m0fnu : public numeric_limits_float8_base { + static inline constexpr const int max_digits10 = + MaxDigits10FromDigits(digits); + // 2**-127 smallest valid normalized value.. +- static inline constexpr const int min_exponent = -127 + 1; ++ static inline constexpr const int min_exponent = -kExponentBias + 1; + static inline constexpr const int min_exponent10 = + MinExponent10FromMinExponent(min_exponent); + // 128 encoding using for NaN +- static inline constexpr const int max_exponent = 127; ++ static inline constexpr const int max_exponent = kExponentBias + 1; + static inline constexpr const int max_exponent10 = + MaxExponent10FromMaxExponentAndDigits(max_exponent, digits); + static inline constexpr const bool is_iec559 = false; +@@ -1292,7 +1292,8 @@ struct Traits : public TraitsBase { + }; + + template +-constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff) { ++constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff, ++ bool use_implicit_bit) { + // Round to nearest even by adding a bias term. + // Consider a bit pattern + // FFF...FLRTT...T, +@@ -1301,9 +1302,12 @@ constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff) { + // - L is 1, R is 1, OR + // - L is 0, R is 1, any T is one. + // We do this by adding L to a bit pattern consisting of all T = 1. +- Bits bias = roundoff == 0 +- ? 0 +- : ((bits >> roundoff) & 1) + (Bits{1} << (roundoff - 1)) - 1; ++ // ++ // When rounding to zero mantissa (E8M0 type), the L bit is implicitly 1 (do ++ // not use the exponent bits for rounding). Add only the R bit in this case. ++ Bits bias = !use_implicit_bit ++ ? ((bits >> roundoff) & 1) + (Bits{1} << (roundoff - 1)) - 1 ++ : Bits{1} << (roundoff - 1); + return bits + bias; + } + +@@ -1443,6 +1447,7 @@ struct ConvertImpl> kFromMantissaBits; ++ const bool to_zero_mantissa = kToMantissaBits == 0; + + // `To` supports more exponents near zero which means that some subnormal + // values in `From` may become normal. +@@ -1473,11 +1478,14 @@ struct ConvertImpl 0) { ++ if constexpr (kDigitShift >= 0) { + bits <<= kDigitShift; + } else { + if constexpr (!kTruncate) { +- bits = RoundBitsToNearestEven(bits, -kDigitShift); ++ // When converting float to e8m0, the bits represent a denormal, ++ // so don't use the implicit mantissa bit for rounding. ++ bits = RoundBitsToNearestEven( ++ bits, -kDigitShift, to_zero_mantissa && kExponentOffset != 0); + } + bits >>= -kDigitShift; + } +@@ -1514,8 +1522,8 @@ struct ConvertImpl> exponent_shift; + } +@@ -1532,7 +1540,8 @@ struct ConvertImpl { + } + + if constexpr (!kTruncate) { +- from_bits = RoundBitsToNearestEven(from_bits, 8); ++ from_bits = RoundBitsToNearestEven(from_bits, 8, false); + // Rounding can cause an overflow to infinity. Clamp to the largest finite + // value if saturation is requested. 
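// [Illustrative sketch, not part of the diff] The RoundBitsToNearestEven tweak
// above can be exercised in isolation. Standalone sketch with hypothetical
// names: when use_implicit_bit is set (zero-mantissa types such as E8M0), the
// "last kept bit" L is treated as 1, so only the rounding bit R feeds the bias.
#include <cstdint>
#include <cstdio>

uint32_t RoundBitsToNearestEvenSketch(uint32_t bits, int roundoff,
                                      bool use_implicit_bit) {
  if (roundoff == 0) return bits;
  uint32_t bias = use_implicit_bit
                      ? (uint32_t{1} << (roundoff - 1))
                      : ((bits >> roundoff) & 1) +
                            (uint32_t{1} << (roundoff - 1)) - 1;
  return bits + bias;
}

int main() {
  // 0b0110, dropping 2 bits: L=1, R=1, T=0 -> tie, rounds up to 0b1000 (8).
  std::printf("%u\n", (RoundBitsToNearestEvenSketch(0b0110, 2, false) >> 2) << 2);
}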
+ if constexpr (kSaturate) { diff --git a/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl b/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl index 0047319ecd918..1478f3cada677 100644 --- a/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl +++ b/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl @@ -12,6 +12,7 @@ def repo(): tf_http_archive( name = "ml_dtypes", build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD", + patch_file = ["//third_party/py/ml_dtypes:e8m0.patch"], link_files = { "//third_party/py/ml_dtypes:ml_dtypes.tests.BUILD": "tests/BUILD.bazel", "//third_party/py/ml_dtypes:LICENSE": "LICENSE", diff --git a/third_party/tsl/tsl/platform/ml_dtypes.h b/third_party/tsl/tsl/platform/ml_dtypes.h index 088398239d079..e1f9e373f669d 100644 --- a/third_party/tsl/tsl/platform/ml_dtypes.h +++ b/third_party/tsl/tsl/platform/ml_dtypes.h @@ -29,6 +29,7 @@ using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; using float8_e5m2 = ::ml_dtypes::float8_e5m2; using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz; +using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu; using int1 = ::ml_dtypes::int1; using uint1 = ::ml_dtypes::uint1; diff --git a/xla/array2d_test.cc b/xla/array2d_test.cc index 8c314edff2eb6..c62f6e882713e 100644 --- a/xla/array2d_test.cc +++ b/xla/array2d_test.cc @@ -233,6 +233,20 @@ TEST(Array2dTest, LinspaceF4E2M1FN) { EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 4.0); // 3.5 rounded up } +TEST(Array2dTest, LinspaceF8E8M0FNU) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 2.0); // 1.5 rounded up + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.0); // 2.5 rounded down + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 4.0); // 3.0 rounded up + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 4.0); // 3.5 rounded up +} + TEST(Array2dTest, Stringification) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); const std::string expected = R"([[1, 1.5], diff --git a/xla/backends/gpu/codegen/transforms/expand_float_ops.cc b/xla/backends/gpu/codegen/transforms/expand_float_ops.cc index a812282069928..90ef36bf37e08 100644 --- a/xla/backends/gpu/codegen/transforms/expand_float_ops.cc +++ b/xla/backends/gpu/codegen/transforms/expand_float_ops.cc @@ -163,7 +163,8 @@ int GetSignificandBits(mlir::FloatType ty) { } int GetExponentBias(mlir::FloatType ty) { - return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()); + return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()) - + ty.isFloat8E8M0FNU(); // No zero exponent for E8M0. } bool IsFNUZ(mlir::FloatType ty) { @@ -215,6 +216,8 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { return (bits & 0b0111'1111) == 0b0111'1111; } else if (ty.isFloat8E3M4()) { return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'0000); + } else if (ty.isFloat8E8M0FNU()) { + return bits == 0xFF; } return bits == 0x80; } @@ -294,6 +297,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, } else { wide_int_ty = b.getIntegerType( std::max(from_int_ty.getWidth(), to_int_ty.getWidth())); + // Avoid overflow for bit shifts. 
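// [Illustrative sketch, not part of the diff] The IsNaN and exponent-bias
// handling above assume the E8M0 encoding: 8 exponent bits, bias 127, no sign
// or mantissa bits, 0xFF reserved for NaN, and no encodings for zero or
// infinity. A standalone decode helper (hypothetical name) under those
// assumptions:
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

double DecodeE8M0(uint8_t bits) {
  if (bits == 0xFF) return std::numeric_limits<double>::quiet_NaN();
  return std::ldexp(1.0, static_cast<int>(bits) - 127);  // 2^(bits - 127)
}

int main() {
  // 127 -> 1.0, 128 -> 2.0, 0 -> 2^-127 (the smallest representable value).
  std::printf("%g %g %g\n", DecodeE8M0(127), DecodeE8M0(128), DecodeE8M0(0));
}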
+ auto may_overflow = [&](mlir::Type a, mlir::Type b) { + return a.isFloat8E8M0FNU() && b.isF16(); + }; + if (may_overflow(from_ty, to_ty) || may_overflow(to_ty, from_ty)) { + wide_int_ty = b.getI32Type(); + } } auto convert_int = [&](mlir::Type ty, Value v) -> Val { if (v.getType() == ty) { @@ -320,11 +330,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, }; // Shift bits to destination type, without sign bit. - Val from_sign_bit = from_bits.shrui(from_width - 1) != 0; - from_bits = from_bits & ((1ULL << (from_width - 1)) - 1); - - Value result_is_inf = IsInf(value, b); - Value input_is_nan = IsNaN(value, b); + Val from_sign_bit; + if (!from_ty.isFloat8E8M0FNU()) { + from_sign_bit = from_bits.shrui(from_width - 1) != 0; + from_bits = from_bits & ((1ULL << (from_width - 1)) - 1); + } auto cst_bits = [&](llvm::APFloat f) { return cst(b.getIntegerType(llvm::APFloat::getSizeInBits(f.getSemantics())), @@ -338,9 +348,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, if (to_ty.isFloat4E2M1FN()) { to_inf = cst_bits(llvm::APFloat::getLargest(to_ty.getFloatSemantics())); to_nan = to_zero | 0x8; + } else if (to_ty.isFloat8E8M0FNU()) { + to_inf = to_nan; + to_zero = Val{to_nan, &b}; } - auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) { + auto round_bits_to_nearest_even = [&](Val bits, Val roundoff, + bool use_implicit_bit = false) { assert(bits.value.getType() == roundoff.value.getType()); // Round to nearest even by adding a bias term. // Consider a bit pattern @@ -350,9 +364,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, // - L is 1, R is 1, OR // - L is 0, R is 1, any T is one. // We do this by adding L to a bit pattern consisting of all T = 1. - Val rounded = (bits.shrui(roundoff) & 1) + - (bits.MakeConstant(1).shl(roundoff - 1) - 1); - Val bias{b.create(roundoff == 0, roundoff, rounded), &b}; + Val bias = !use_implicit_bit + ? (bits.shrui(roundoff) & 1) + + (bits.MakeConstant(1).shl(roundoff - 1) - 1) + : bits.MakeConstant(1).shl(roundoff - 1); return bits + bias; }; @@ -362,9 +377,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, // Round the mantissa if it is shrinking. Val rounded_from_bits = convert_int(wide_int_ty, from_bits); if (digit_shift < 0) { - rounded_from_bits = round_bits_to_nearest_even( - from_bits, from_bits.MakeConstant(-digit_shift)) & - ~((1ll << (-digit_shift)) - 1); + rounded_from_bits = + round_bits_to_nearest_even( + rounded_from_bits, rounded_from_bits.MakeConstant(-digit_shift), + /*use_implicit_bit=*/to_mantissa == 0) & + ~((1ll << (-digit_shift)) - 1); } // Re-bias the exponent. @@ -431,10 +448,12 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, Value biased_exp_sle_zero = biased_exponent.cmp(CmpIPredicate::sle, 0); bits.value = b.create(biased_exp_sle_zero, subnormal_bits, normal_bits); - if (digit_shift > 0) { + if (digit_shift >= 0) { bits = bits.shl(digit_shift); } else { - bits = round_bits_to_nearest_even(bits, bits.MakeConstant(-digit_shift)); + bits = round_bits_to_nearest_even( + bits, bits.MakeConstant(-digit_shift), + /*use_implicit_bit=*/to_mantissa == 0 && exp_offset != 0); bits = bits.shrui(-digit_shift); } bits = convert_int(to_int_ty, bits); @@ -443,11 +462,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, } else if (to_min_exp > from_min_exp) { // `To` supports fewer exponents near zero which means that some values in // `From` may become subnormal. 
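// [Illustrative sketch, not part of the diff] End to end, the conversion logic
// above gives f32 -> f8e8m0fnu roughly the following behaviour (assumed,
// simplified; subnormal inputs ignored): keep the 8-bit exponent, round to the
// nearest power of two with ties going up, and map negatives and inf/NaN to
// 0xFF since they are unrepresentable.
#include <cstdint>
#include <cstdio>
#include <cstring>

uint8_t F32ToE8M0Sketch(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if (bits >> 31) return 0xFF;                    // negative -> NaN
  uint32_t exponent = bits >> 23;                 // biased exponent (bias 127)
  uint32_t mantissa = bits & 0x7FFFFF;
  if (exponent == 0xFF) return 0xFF;              // inf/NaN -> NaN
  if (mantissa >= (1u << 22)) ++exponent;         // fraction >= 0.5: round up
  return exponent > 0xFE ? 0xFF : static_cast<uint8_t>(exponent);
}

int main() {
  // 1.0 -> 127 (2^0), 2.5 -> 128 (rounds down to 2), 3.0 -> 129 (tie, up to 4).
  std::printf("%u %u %u\n", F32ToE8M0Sketch(1.0f), F32ToE8M0Sketch(2.5f),
              F32ToE8M0Sketch(3.0f));
}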
- Val unbiased_exp = biased_from_exp - from_bias; - Val biased_to_exp = unbiased_exp + to_bias; + Val biased_to_exp = biased_from_exp + (to_bias - from_bias); // Subnormals and zero. // Round and shift mantissa down. - Val from_has_leading_one = biased_from_exp != 0; + Val from_has_leading_one = + !from_ty.isFloat8E8M0FNU() ? biased_from_exp != 0 : cst(i32_ty, 1); Val from_has_leading_one_i32 = convert_int(i32_ty, from_has_leading_one); from_has_leading_one = convert_int(from_int_ty, from_has_leading_one); Val exponent_shift_i32 = @@ -482,7 +501,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, result); } - if (IsFNUZ(to_ty)) { + Value result_is_inf = IsInf(value, b); + Value input_is_nan = IsNaN(value, b); + + if (to_ty.isFloat8E8M0FNU()) { + // Converting a negative number to E8M0 results in NaN. + input_is_nan = from_sign_bit | input_is_nan; + } else if (IsFNUZ(to_ty)) { // Clear the sign bit if the result is zero (the output has no negative // zero). Handle the edge case when the input is zero and the result is not. Val result_is_non_zero = @@ -494,14 +519,17 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, from_sign_bit = from_sign_bit ^ input_is_nan; } + if (!from_ty.isFloat8E8M0FNU()) { + result = b.create(from_bits == 0, to_zero, result); + } result = b.create(result_is_inf, to_inf, result); - result = b.create(from_bits == 0, to_zero, result); result = b.create(input_is_nan, to_nan, result); - Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1)); - // Insert sign bit. - result = b.create(from_sign_bit, neg_result, result); + if (!from_ty.isFloat8E8M0FNU()) { + Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1)); + result = b.create(from_sign_bit, neg_result, result); + } result = b.create(to_ty, result); return result; } @@ -598,6 +626,14 @@ struct RewriteAbsFPattern : public mlir::OpRewritePattern { return rewriter.notifyMatchFailure(op, "not an f8 (or less) or bf16 absf"); } + + // If type is unsigned (E8M0), the operation is no-op. 
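// [Illustrative sketch, not part of the diff] The absf rewrite below relies on
// E8M0 having no sign bit: every non-NaN encoding is a positive power of two,
// so |x| == x and the op can simply be replaced by its operand. Standalone
// check under the assumed 2^(bits - 127) decoding:
#include <cassert>
#include <cmath>

int main() {
  for (int bits = 0; bits < 0xFF; ++bits) {      // 0xFF encodes NaN, skip it
    double v = std::ldexp(1.0, bits - 127);      // assumed decode: 2^(bits-127)
    assert(std::fabs(v) == v);                   // abs never changes the value
  }
  return 0;
}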
+ if (!llvm::APFloat::semanticsHasSignedRepr( + src.getType().getFloatSemantics())) { + rewriter.replaceAllOpUsesWith(op, op.getOperand()); + return mlir::success(); + } + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth()); Val value{b.create(i_ty, src), &b}; diff --git a/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir b/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir index 69187e3e27486..dea8988d474b0 100644 --- a/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir +++ b/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir @@ -152,3 +152,16 @@ module { // CHECK-LABEL: @f4_abs // CHECK-NOT: math.absf // CHECK: arith.constant 7 : i4 + +// ----- + +module { + func.func @e8m0_abs(%arg0: f8E8M0FNU) -> f8E8M0FNU { + %ret = math.absf %arg0 : f8E8M0FNU + return %ret : f8E8M0FNU + } +} + +// CHECK-LABEL: @e8m0_abs +// CHECK-NOT: math.absf +// CHECK: return %arg0 diff --git a/xla/comparison_util.h b/xla/comparison_util.h index 5a21595da4d74..44f0dd48640bb 100644 --- a/xla/comparison_util.h +++ b/xla/comparison_util.h @@ -193,8 +193,13 @@ class Comparison { // -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN // Reference: // https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations - using R = SignedIntegerTypeForSizeType; - return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); + if constexpr (std::numeric_limits::is_signed) { + using R = SignedIntegerTypeForSizeType; + return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); + } else { + using R = UnsignedIntegerTypeForSizeType; + return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); + } } } // Applies the comparison from this Comparison's direction and ordering. 
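// [Illustrative sketch, not part of the diff] The unsigned branch added to
// Compare above works because an unsigned-only format like E8M0 is already
// monotonic in its raw bits, so no sign-magnitude fixup is needed for the
// total order. Standalone check under the assumed 2^(bits - 127) decoding:
#include <cassert>
#include <cmath>

int main() {
  for (int a = 0; a < 0xFF; ++a) {               // 0xFF (NaN) sorts last, skip
    for (int b = 0; b < 0xFF; ++b) {
      double va = std::ldexp(1.0, a - 127);
      double vb = std::ldexp(1.0, b - 127);
      assert((a < b) == (va < vb));              // bit order matches value order
    }
  }
  return 0;
}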
diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h index af37257f56e14..9787476f8f7ea 100644 --- a/xla/ffi/api/api.h +++ b/xla/ffi/api/api.h @@ -147,6 +147,8 @@ inline std::ostream& operator<<(std::ostream& os, return os << "F8E5M2FNUZ"; case XLA_FFI_DataType_F8E4M3FNUZ: return os << "F8E4M3FNUZ"; + case XLA_FFI_DataType_F8E8M0FNU: + return os << "F8E8M0FNU"; } } diff --git a/xla/ffi/api/c_api.h b/xla/ffi/api/c_api.h index 96c6543462ba0..e344575e29ea9 100644 --- a/xla/ffi/api/c_api.h +++ b/xla/ffi/api/c_api.h @@ -202,6 +202,7 @@ typedef enum { XLA_FFI_DataType_F8E5M2FNUZ = 24, XLA_FFI_DataType_F8E4M3FNUZ = 25, XLA_FFI_DataType_F4E2M1FN = 30, + XLA_FFI_DataType_F8E8M0FNU = 31, } XLA_FFI_DataType; // LINT.ThenChange(ffi_test.cc) diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h index b4602f0627725..aeeab1d505ab6 100644 --- a/xla/ffi/api/ffi.h +++ b/xla/ffi/api/ffi.h @@ -80,6 +80,7 @@ enum class DataType : uint8_t { F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ, F8E3M4 = XLA_FFI_DataType_F8E3M4, F4E2M1FN = XLA_FFI_DataType_F4E2M1FN, + F8E8M0FNU = XLA_FFI_DataType_F8E8M0FNU, }; // Create aliases in ::xla::ffi namespace for all DataTypes, for consistency @@ -108,6 +109,7 @@ inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ; inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ; inline constexpr DataType F8E3M4 = DataType::F8E3M4; inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN; +inline constexpr DataType F8E8M0FNU = DataType::F8E8M0FNU; inline std::ostream& operator<<(std::ostream& os, const DataType dtype) { return os << static_cast(dtype); @@ -130,6 +132,7 @@ constexpr size_t ByteWidth(DataType dtype) { case DataType::F8E4M3FNUZ: case DataType::F8E3M4: case DataType::F4E2M1FN: + case DataType::F8E8M0FNU: return 1; case DataType::S16: case DataType::U16: diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index fcac1e68d04fd..3c51a0966ae02 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -138,6 +138,7 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4)); + EXPECT_EQ(encoded(PrimitiveType::F8E8M0FNU), encoded(DataType::F8E8M0FNU)); } TEST(FfiTest, DataTypeByteWidth) { @@ -196,6 +197,8 @@ TEST(FfiTest, DataTypeByteWidth) { ByteWidth(DataType::F8E4M3FNUZ)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4), ByteWidth(DataType::F8E3M4)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E8M0FNU), + ByteWidth(DataType::F8E8M0FNU)); } TEST(FfiTest, ErrorEnumValue) { diff --git a/xla/ffi/call_frame.cc b/xla/ffi/call_frame.cc index c751fbe547bf4..7bcb14da445e8 100644 --- a/xla/ffi/call_frame.cc +++ b/xla/ffi/call_frame.cc @@ -272,6 +272,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::F8E5M2FNUZ: case PrimitiveType::F8E4M3FNUZ: case PrimitiveType::F8E3M4: + case PrimitiveType::F8E8M0FNU: return static_cast(primitive_type); default: DCHECK(false) << "Unsupported primitive type " diff --git a/xla/fp_util_test.cc b/xla/fp_util_test.cc index 09b66c2984b05..8ea22d9d1602b 100644 --- a/xla/fp_util_test.cc +++ b/xla/fp_util_test.cc @@ -172,6 +172,23 @@ TEST(FPDistanceTest, F4E2M1FNDistance) { 4); } +TEST(FPDistanceTest, F8E8M0FNUDistance) { + // a & b are equal + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(1.0)), + 0); + + // one step apart + 
EXPECT_EQ(CalculateDistanceInFloats( + tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(2.0)), + 1); + + // two steps apart + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float8_e8m0fnu(0.5), tsl::float8_e8m0fnu(2.0)), + 2); +} + TEST(FPDistanceTest, F8E3M4Distance) { // a & b are equal EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), diff --git a/xla/hlo/evaluator/hlo_evaluator.cc b/xla/hlo/evaluator/hlo_evaluator.cc index 35fac878f104d..8e44243823c09 100644 --- a/xla/hlo/evaluator/hlo_evaluator.cc +++ b/xla/hlo/evaluator/hlo_evaluator.cc @@ -3722,7 +3722,7 @@ absl::StatusOr StochasticConvertOp(const Literal& operand_literal, const Shape& result_shape) { std::function stochastic_convert_op = [](Fp operand, Uint random) -> ResultT { - bool is_negative = static_cast(Eigen::numext::signbit(operand)); + bool is_negative = static_cast(SignAndMagnitude(operand).first); if (Eigen::numext::isinf(operand)) { return is_negative ? std::numeric_limits::min() : std::numeric_limits::max(); diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 0d1ebe15880ad..7f0925f1a3179 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1742,6 +1742,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc index f665306936d63..6bc96c1a1f7cd 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc @@ -19,4 +19,5 @@ limitations under the License. 
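// [Illustrative sketch, not part of the diff] For the F8E8M0FNUDistance test
// above: consecutive E8M0 encodings differ by a factor of two, so the distance
// in floats between two positive powers of two is the difference of their
// exponents. Plain-double illustration with a hypothetical helper:
#include <cmath>
#include <cstdio>
#include <cstdlib>

int E8M0DistanceSketch(double a, double b) {
  return std::abs(std::ilogb(a) - std::ilogb(b));
}

int main() {
  // 0, 1 and 2 steps, matching the expectations in the test above.
  std::printf("%d %d %d\n", E8M0DistanceSketch(1.0, 1.0),
              E8M0DistanceSketch(1.0, 2.0), E8M0DistanceSketch(0.5, 2.0));
}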
namespace xla { template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/hlo/transforms/expanders/comparison_expander.cc b/xla/hlo/transforms/expanders/comparison_expander.cc index 0f09ecced1eba..86d1eeafcd593 100644 --- a/xla/hlo/transforms/expanders/comparison_expander.cc +++ b/xla/hlo/transforms/expanders/comparison_expander.cc @@ -115,34 +115,41 @@ absl::StatusOr ComparisonExpander::ExpandInstruction( ShapeUtil::ChangeElementType(rhs->shape(), compare_type), rhs)); } - int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type()); - PrimitiveType signed_type = - primitive_util::SignedIntegralTypeForBitWidth(bit_width); - auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type); - - auto zero_value = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); - zero_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, zero_value, {})); - - auto min_value = computation->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MinValue(signed_shape.element_type()))); - min_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, min_value, {})); - - auto max_value = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); - max_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, max_value, {})); - - lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value, - min_value, max_value); - rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value, - min_value, max_value); + if (compare_type != F8E8M0FNU) { + int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type()); + PrimitiveType signed_type = + primitive_util::SignedIntegralTypeForBitWidth(bit_width); + auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type); + + auto zero_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); + zero_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, zero_value, {})); + + auto min_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MinValue(signed_type))); + min_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, min_value, {})); + + auto max_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); + max_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, max_value, {})); + + lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value, + min_value, max_value); + rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value, + min_value, max_value); + } else { + auto int8_shape = ShapeUtil::ChangeElementType(lhs->shape(), U8); + lhs = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(int8_shape, lhs)); + rhs = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(int8_shape, rhs)); + } auto new_compare = computation->AddInstruction(HloInstruction::CreateCompare( - instruction->shape(), lhs, rhs, compare->direction(), - Comparison::Type::kSigned)); + instruction->shape(), lhs, rhs, compare->direction())); VLOG(2) << "New comparison instruction for total order:" << new_compare->ToString(); diff --git a/xla/hlo/transforms/simplifiers/float_normalization_test.cc 
b/xla/hlo/transforms/simplifiers/float_normalization_test.cc index 9017b6dc47b61..b614f74229c0e 100644 --- a/xla/hlo/transforms/simplifiers/float_normalization_test.cc +++ b/xla/hlo/transforms/simplifiers/float_normalization_test.cc @@ -152,7 +152,7 @@ class FloatNormalizationF8Test INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test, ::testing::Values(F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, - F8E5M2, F8E5M2FNUZ)); + F8E5M2, F8E5M2FNUZ, F8E8M0FNU)); TEST_F(FloatNormalizationTest, NoopIfSupported) { auto builder = HloComputation::Builder(TestName()); diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo index 915783991915e..577e4ad61f89e 100644 --- a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo +++ b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo @@ -424,6 +424,9 @@ add { // CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN> %constant.14 = f4e2m1fn[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_15:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00]> : tensor<4xf8E8M0FNU> + %constant.15 = f8e8m0fnu[4] constant({1, 2, 4, 8}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -551,7 +554,13 @@ add { %convert.17 = f4e2m1fn[4] convert(f32[4] %convert.16) // CHECK-NEXT: %15 = mhlo.convert %14 : (tensor<4xf4E2M1FN>) -> tensor<4xf32> - ROOT %convert.18 = f32[4] convert(f4e2m1fn[4] %convert.17) + %convert.18 = f32[4] convert(f4e2m1fn[4] %convert.17) + + // CHECK-NEXT: %16 = mhlo.convert %15 : (tensor<4xf32>) -> tensor<4xf8E8M0FNU> + %convert.19 = f8e8m0fnu[4] convert(f32[4] %convert.18) + + // CHECK-NEXT: %17 = mhlo.convert %16 : (tensor<4xf8E8M0FNU>) -> tensor<4xf32> + ROOT %convert.20 = f32[4] convert(f8e8m0fnu[4] %convert.19) } // CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8> diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir index 6f8f66d096668..c017751477cb5 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir @@ -609,6 +609,9 @@ func.func @main() { // CHECK: f4e2m1fn[4] constant({1, 2, 3, 4}) %cst_18 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN> + // CHECK: f8e8m0fnu[4] constant({1, 2, 4, 8}) + %cst_19 = arith.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00]> : tensor<4xf8E8M0FNU> + func.return } @@ -744,7 +747,9 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { %11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32> %12 = "mhlo.convert"(%11) : (tensor<2xf32>) -> tensor<2xf4E2M1FN> %13 = "mhlo.convert"(%12) : (tensor<2xf4E2M1FN>) -> tensor<2xf32> - func.return %13 : tensor<2xf32> + %14 = "mhlo.convert"(%13) : (tensor<2xf32>) -> tensor<2xf8E8M0FNU> + %15 = "mhlo.convert"(%14) : (tensor<2xf8E8M0FNU>) -> tensor<2xf32> + func.return %15 : tensor<2xf32> } // CHECK: ENTRY @@ -762,7 +767,9 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]]) // CHECK: %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) // CHECK: %[[E2M1FN_VAL:.*]] = f4e2m1fn[2] convert(f32[2] %[[F32_VAL6]]) -// CHECK: ROOT %[[F32_VAL7:.*]] = f32[2] convert(f4e2m1fn[2] %[[E2M1FN_VAL]]) +// CHECK: %[[F32_VAL7:.*]] = f32[2] 
convert(f4e2m1fn[2] %[[E2M1FN_VAL]]) +// CHECK: %[[E8M0FNU_VAL:.*]] = f8e8m0fnu[2] convert(f32[2] %[[F32_VAL7]]) +// CHECK: ROOT %[[F32_VAL8:.*]] = f32[2] convert(f8e8m0fnu[2] %[[E8M0FNU_VAL]]) // ----- diff --git a/xla/literal.cc b/xla/literal.cc index 37ca5843e9c63..866bc1838a919 100644 --- a/xla/literal.cc +++ b/xla/literal.cc @@ -95,9 +95,10 @@ bool LiteralProtoHasValues(const LiteralProto& proto) { !proto.f8e4m3b11fnuzs().empty() || !proto.f8e4m3fns().empty() || !proto.f8e4m3fnuzs().empty() || !proto.f8e4m3s().empty() || !proto.f8e5m2fnuzs().empty() || !proto.f8e5m2s().empty() || - !proto.f16s().empty() || !proto.bf16s().empty() || proto.f32s_size() || - proto.f64s_size() || proto.c64s_size() || proto.c128s_size() || - proto.preds_size() || proto.tuple_literals_size(); + !proto.f8e8m0fnus().empty() || !proto.f16s().empty() || + !proto.bf16s().empty() || proto.f32s_size() || proto.f64s_size() || + proto.c64s_size() || proto.c128s_size() || proto.preds_size() || + proto.tuple_literals_size(); } // Lazy getter for the interned scalar shape in static storage. We reuse this @@ -2298,6 +2299,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; + case F8E8M0FNU: + *proto->mutable_f8e8m0fnus() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F16: *proto->mutable_f16s() = std::string(reinterpret_cast(data().data()), @@ -2510,6 +2516,14 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { memcpy(untyped_data(), s.data(), s.size()); break; } + case F8E8M0FNU: { + const std::string& s(proto.f8e8m0fnus()); + TF_RET_CHECK(data().size() * + sizeof(tsl::float8_e8m0fnu) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F16: { const std::string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/xla/literal.h b/xla/literal.h index 34af5c4b275ed..cf544aa3a0c0c 100644 --- a/xla/literal.h +++ b/xla/literal.h @@ -591,7 +591,14 @@ class LiteralBase { if constexpr (bits_per_element < 8) { static_assert(!primitive_util::IsComplexType(primitive_type)); static_assert(8 % bits_per_element == 0); + constexpr int elements_per_byte = 8 / bits_per_element; + constexpr auto cast = [](NativeT x) -> uint8_t { + if constexpr (primitive_util::IsFloatingPointType(primitive_type)) { + return Eigen::numext::bit_cast(x); + } + return static_cast(x); + }; int64_t bytes = elements.size() / elements_per_byte; for (int64_t i = 0; i < bytes; ++i) { diff --git a/xla/literal_comparison.cc b/xla/literal_comparison.cc index c97629594122b..fee817978ec3e 100644 --- a/xla/literal_comparison.cc +++ b/xla/literal_comparison.cc @@ -418,6 +418,9 @@ class NearComparator { } else { float_distance = CalculateFloatDistance(expected, actual); abs_error = FpAbsoluteValue(actual - expected); + if (!std::numeric_limits::is_signed && IsNaN(abs_error)) { + abs_error = FpAbsoluteValue(expected - actual); + } // Avoid division by 0 even though it's well-defined because ubsan can be // configured to treat this as a fatal error. 
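// [Illustrative sketch, not part of the diff] The NearComparator fix above is
// needed because for an unsigned format such as E8M0 the difference
// `actual - expected` is negative (hence unrepresentable, hence NaN) whenever
// actual < expected; recomputing `expected - actual` recovers the magnitude.
// Plain-double illustration with a hypothetical subtraction helper:
#include <cmath>
#include <cstdio>

double E8M0SubSketch(double a, double b) {
  double d = a - b;
  return d < 0 ? NAN : d;           // negative results have no E8M0 encoding
}

int main() {
  double abs_error = E8M0SubSketch(2.0, 4.0);                  // NaN
  if (std::isnan(abs_error)) abs_error = E8M0SubSketch(4.0, 2.0);
  std::printf("%g\n", abs_error);                              // 2
}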
diff --git a/xla/literal_comparison_test.cc b/xla/literal_comparison_test.cc index ee83ecaf66687..ee8aafda2c350 100644 --- a/xla/literal_comparison_test.cc +++ b/xla/literal_comparison_test.cc @@ -29,10 +29,11 @@ namespace { template class LiteralComparisonTest : public ::testing::Test {}; -using TestedTypes = ::testing::Types; +using TestedTypes = + ::testing::Types; TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes); TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { @@ -53,6 +54,8 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { expV = 1.0625; else if (type == F4E2M1FN) expV = 1.5; + else if (type == F8E8M0FNU) + expV = 2.0; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, @@ -75,6 +78,8 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { expV = 1.25; else if (type == F4E2M1FN) expV = 4.0; + else if (type == F8E8M0FNU) + expV = 16.0; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; @@ -99,6 +104,8 @@ TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) { expV = 1.26; else if (type == F4E2M1FN) expV = 4.1; + else if (type == F8E8M0FNU) + expV = 16.5; auto expected = LiteralUtil::CreateR0(expV); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; diff --git a/xla/literal_test.cc b/xla/literal_test.cc index 3955835c5560f..7aa9f2dc040dc 100644 --- a/xla/literal_test.cc +++ b/xla/literal_test.cc @@ -124,10 +124,11 @@ class LiteralUtilTest : public ::testing::Test { template class LiteralUtilFloatTest : public LiteralUtilTest {}; -using FloatTypes = ::testing::Types< - float, half, bfloat16, tsl::float4_e2m1fn, tsl::float8_e3m4, - tsl::float8_e4m3, tsl::float8_e4m3b11fnuz, tsl::float8_e4m3fn, - tsl::float8_e4m3fnuz, tsl::float8_e5m2, tsl::float8_e5m2fnuz>; +using FloatTypes = ::testing::Types; TYPED_TEST_SUITE(LiteralUtilFloatTest, FloatTypes); @@ -210,6 +211,10 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f8e3m4_lit = LiteralUtil::CreateR0(tsl::float8_e3m4(0.5)); EXPECT_EQ("f8e3m4[] 0.5", f8e3m4_lit.ToString()); + + auto f8e8m0fnu_lit = + LiteralUtil::CreateR0(tsl::float8_e8m0fnu(0.5)); + EXPECT_EQ("f8e8m0fnu[] 0.5", f8e8m0fnu_lit.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -697,6 +702,11 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(LiteralUtil::CreateR1({v16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({v16}).IsAll(9)); + tsl::float8_e8m0fnu w16(8); + EXPECT_TRUE(LiteralUtil::CreateR1({w16}).IsAll(8)); + // 9 rounds to 8 in E8M0FNU but is not equal to 8, so this should be false + EXPECT_FALSE(LiteralUtil::CreateR1({w16}).IsAll(9)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); @@ -2245,6 +2255,9 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { LiteralUtil::CreateR1({e4f{10.0}, e4f{20.0}, e4f{-30.0}}); using e3 = tsl::float8_e3m4; auto vector_f8e3m4 = LiteralUtil::CreateR1({e3{2.5}, e3{5.0}, e3{-8.0}}); + using e8m0 = tsl::float8_e8m0fnu; + auto vector_f8e8m0fnu = + LiteralUtil::CreateR1({e8m0{1.0}, e8m0{2.0}, e8m0{4.0}}); auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto vector_s4 = LiteralUtil::CreateR1({s4{-1}, s4{3}, s4{7}}); @@ -2273,6 +2286,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_f8e4m3fnuz, to_from_proto(vector_f8e4m3fnuz)); 
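// [Illustrative sketch, not part of the diff] The expected values above follow
// from E8M0 holding only powers of two: one ulp step multiplies the value by
// two, so 1 ulp above 1.0 is 2.0 and 4 ulps above 1.0 is 16.0.
#include <cstdio>

int main() {
  double v = 1.0;
  for (int ulps = 1; ulps <= 4; ++ulps) {
    v *= 2.0;                                          // next representable E8M0 value
    std::printf("%d ulp(s) above 1.0: %g\n", ulps, v); // 2, 4, 8, 16
  }
}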
EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); + EXPECT_EQ(vector_f8e8m0fnu, to_from_proto(vector_f8e8m0fnu)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(vector_s4, to_from_proto(vector_s4)); EXPECT_EQ(vector_u4, to_from_proto(vector_u4)); @@ -2523,8 +2537,8 @@ TEST_F(LiteralUtilTest, SliceOnBool) { } TEST_F(LiteralUtilTest, IsEqualAt) { - double val_double = 6.0; - int val_integral = 6; + double val_double = 4.0; + int val_integral = 4; Literal c1 = LiteralUtil::CreateR0(val_integral); EXPECT_TRUE(c1.IsEqualAt({}, val_double)); EXPECT_TRUE(c1.IsEqualAt({}, val_integral)); @@ -2573,6 +2587,10 @@ TEST_F(LiteralUtilTest, IsEqualAt) { LiteralUtil::CreateR0(tsl::float4_e2m1fn{val_double}); EXPECT_TRUE(c11.IsEqualAt({}, val_double)); EXPECT_TRUE(c11.IsEqualAt({}, val_integral)); + Literal c12 = LiteralUtil::CreateR0( + tsl::float8_e8m0fnu{val_double}); + EXPECT_TRUE(c12.IsEqualAt({}, val_double)); + EXPECT_TRUE(c12.IsEqualAt({}, val_integral)); } TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { @@ -2901,8 +2919,8 @@ class LiteralSerializationTest : public ::testing::Test, {PRED, S4, U4, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, BF16, F4E2M1FN, F8E3M4, F8E4M3, - F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, C64, - C128}) { + F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F8E8M0FNU, + C64, C128}) { for (const DimensionVector& dimensions : { DimensionVector{}, DimensionVector{0}, diff --git a/xla/mlir/utils/type_util.cc b/xla/mlir/utils/type_util.cc index 1b7ab6d08a9ac..ea8da4d4990d9 100644 --- a/xla/mlir/utils/type_util.cc +++ b/xla/mlir/utils/type_util.cc @@ -48,6 +48,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( return b.getFloat8E4M3FNUZType(); case xla::PrimitiveType::F8E3M4: return b.getFloat8E3M4Type(); + case xla::PrimitiveType::F8E8M0FNU: + return b.getFloat8E8M0FNUType(); case xla::PrimitiveType::F16: return b.getF16Type(); case xla::PrimitiveType::BF16: @@ -96,6 +98,8 @@ xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { return xla::PrimitiveType::F8E5M2FNUZ; } else if (type.isFloat8E3M4()) { return xla::PrimitiveType::F8E3M4; + } else if (type.isFloat8E8M0FNU()) { + return xla::PrimitiveType::F8E8M0FNU; } else if (type.isBF16()) { return xla::PrimitiveType::BF16; } else if (type.isF16()) { diff --git a/xla/mlir/utils/type_util_test.cc b/xla/mlir/utils/type_util_test.cc index 5f9d21e14cac1..2239943d906b7 100644 --- a/xla/mlir/utils/type_util_test.cc +++ b/xla/mlir/utils/type_util_test.cc @@ -112,6 +112,7 @@ INSTANTIATE_TEST_SUITE_P( {F8E4M3FNUZ, [](mlir::Builder b) { return b.getFloat8E4M3FNUZType(); }}, {F8E3M4, [](mlir::Builder b) { return b.getFloat8E3M4Type(); }}, + {F8E8M0FNU, [](mlir::Builder b) { return b.getFloat8E8M0FNUType(); }}, {F16, [](mlir::Builder b) { return b.getF16Type(); }}, {BF16, [](mlir::Builder b) { return b.getBF16Type(); }}, {F32, [](mlir::Builder b) { return b.getF32Type(); }}, diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index d5fde8a5c7305..16a64cdc22b76 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -6879,6 +6879,13 @@ func.func @f8e5m2(%arg0: tensor) -> tensor { // ----- +func.func @f8e8m0fnu(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @top_k_1d(%arg0 : tensor<16xf32>) { %0:2 = mhlo.topk(%arg0, k=8, 
largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) return diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index e388f61aecaa9..ad1b774bc7f15 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -649,6 +649,7 @@ typedef enum { // More truncated 8 bit floating-point formats. PJRT_Buffer_Type_F8E4M3, PJRT_Buffer_Type_F8E3M4, + PJRT_Buffer_Type_F8E8M0FNU, // 4-bit MX floating-point format. PJRT_Buffer_Type_F4E2M1FN, diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index e8384ddb3208c..349430577a0dc 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -310,6 +310,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ; case xla::PrimitiveType::F8E3M4: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4; + case xla::PrimitiveType::F8E8M0FNU: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E8M0FNU; case xla::PrimitiveType::C64: return PJRT_Buffer_Type::PJRT_Buffer_Type_C64; case xla::PrimitiveType::C128: @@ -379,6 +381,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::F8E4M3FNUZ; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4: return xla::PrimitiveType::F8E3M4; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E8M0FNU: + return xla::PrimitiveType::F8E8M0FNU; case PJRT_Buffer_Type::PJRT_Buffer_Type_INVALID: CHECK(false) << "Buffer type is not supported in C API layer."; } diff --git a/xla/primitive_util.h b/xla/primitive_util.h index 9bbf04d6b7e07..70a8335c8bc51 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h @@ -228,6 +228,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return F8E3M4; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F8E8M0FNU; +} + // Complex template <> constexpr PrimitiveType NativeToPrimitiveType() { @@ -382,6 +387,11 @@ struct PrimitiveTypeToNative { using type = tsl::float8_e3m4; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float8_e8m0fnu; +}; + // Complex template <> struct PrimitiveTypeToNative { @@ -414,7 +424,9 @@ inline constexpr bool IsArrayType(PrimitiveType primitive_type) { primitive_type < PrimitiveType_ARRAYSIZE; } -constexpr bool IsMXType(PrimitiveType type) { return type == F4E2M1FN; } +constexpr bool IsMXType(PrimitiveType type) { + return type == F4E2M1FN || type == F8E8M0FNU; +} constexpr bool IsF8Type(PrimitiveType type) { return type == F8E5M2 || type == F8E4M3 || type == F8E4M3FN || @@ -512,6 +524,9 @@ constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { case F8E5M2FNUZ: return std::forward(f)( PrimitiveTypeConstant()); + case F8E8M0FNU: + return std::forward(f)( + PrimitiveTypeConstant()); case F16: return std::forward(f)(PrimitiveTypeConstant()); case BF16: @@ -736,6 +751,10 @@ inline bool CastPreservesValues(PrimitiveType from_type, if (from_type == to_type) { return true; } + // * -> F8E8M0FNU is not possible because zero cannot be represented. 
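// [Illustrative sketch, not part of the diff] The check added below makes
// CastPreservesValues return false for every source type when the target is
// F8E8M0FNU, because the target cannot represent zero (or any value that is
// not a power of two); even PRED -> F8E8M0FNU loses 0. A standalone predicate
// for "survives a round trip", under the assumed encoding:
#include <cmath>
#include <cstdio>

bool SurvivesE8M0RoundTrip(double v) {
  if (v <= 0.0) return false;                    // no zero, no negatives
  int exp;
  double frac = std::frexp(v, &exp);             // v = frac * 2^exp, frac in [0.5, 1)
  return frac == 0.5 && exp - 1 >= -127 && exp - 1 <= 127;  // exact in-range power of two
}

int main() {
  std::printf("%d %d %d\n", SurvivesE8M0RoundTrip(0.0),
              SurvivesE8M0RoundTrip(1.0), SurvivesE8M0RoundTrip(3.0));  // 0 1 0
}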
+ if (to_type == F8E8M0FNU) { + return false; + } // PRED -> * if (from_type == PRED) { return true; diff --git a/xla/primitive_util_test.cc b/xla/primitive_util_test.cc index c980dba72cf75..da1a6afdf3b87 100644 --- a/xla/primitive_util_test.cc +++ b/xla/primitive_util_test.cc @@ -99,6 +99,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S1][F8E5M2FNUZ] = true; expecteds[S1][F8E4M3FNUZ] = true; expecteds[S1][F8E3M4] = true; + expecteds[S1][F8E8M0FNU] = false; expecteds[S2][PRED] = false; expecteds[S2][S1] = false; expecteds[S2][S2] = true; @@ -128,6 +129,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][F8E5M2FNUZ] = true; expecteds[S2][F8E4M3FNUZ] = true; expecteds[S2][F8E3M4] = true; + expecteds[S2][F8E8M0FNU] = false; expecteds[S4][PRED] = false; expecteds[S4][S1] = false; expecteds[S4][S2] = false; @@ -157,6 +159,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][F8E5M2FNUZ] = true; expecteds[S4][F8E4M3FNUZ] = true; expecteds[S4][F8E3M4] = true; + expecteds[S4][F8E8M0FNU] = false; expecteds[S8][PRED] = false; expecteds[S8][S1] = false; expecteds[S8][S2] = false; @@ -186,6 +189,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][F8E5M2FNUZ] = false; expecteds[S8][F8E4M3FNUZ] = false; expecteds[S8][F8E3M4] = false; + expecteds[S8][F8E8M0FNU] = false; expecteds[S16][PRED] = false; expecteds[S16][S1] = false; expecteds[S16][S2] = false; @@ -215,6 +219,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][F8E5M2FNUZ] = false; expecteds[S16][F8E4M3FNUZ] = false; expecteds[S16][F8E3M4] = false; + expecteds[S16][F8E8M0FNU] = false; expecteds[S32][PRED] = false; expecteds[S32][S1] = false; expecteds[S32][S2] = false; @@ -244,6 +249,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][F8E5M2FNUZ] = false; expecteds[S32][F8E4M3FNUZ] = false; expecteds[S32][F8E3M4] = false; + expecteds[S32][F8E8M0FNU] = false; expecteds[S64][PRED] = false; expecteds[S64][S1] = false; expecteds[S64][S2] = false; @@ -273,6 +279,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][F8E5M2FNUZ] = false; expecteds[S64][F8E4M3FNUZ] = false; expecteds[S64][F8E3M4] = false; + expecteds[S64][F8E8M0FNU] = false; expecteds[U1][PRED] = false; expecteds[U1][S1] = false; expecteds[U1][S2] = true; @@ -302,6 +309,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U1][F8E5M2FNUZ] = true; expecteds[U1][F8E4M3FNUZ] = true; expecteds[U1][F8E3M4] = true; + expecteds[U1][F8E8M0FNU] = false; expecteds[U2][PRED] = false; expecteds[U2][S1] = false; expecteds[U2][S2] = false; @@ -331,6 +339,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][F8E5M2FNUZ] = true; expecteds[U2][F8E4M3FNUZ] = true; expecteds[U2][F8E3M4] = true; + expecteds[U2][F8E8M0FNU] = false; expecteds[U4][PRED] = false; expecteds[U4][S1] = false; expecteds[U4][S2] = false; @@ -360,6 +369,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][F8E5M2FNUZ] = false; expecteds[U4][F8E4M3FNUZ] = true; expecteds[U4][F8E3M4] = true; + expecteds[U4][F8E8M0FNU] = false; expecteds[U8][PRED] = false; expecteds[U8][S1] = false; expecteds[U8][S2] = false; @@ -389,6 +399,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][F8E5M2FNUZ] = false; expecteds[U8][F8E4M3FNUZ] = false; expecteds[U8][F8E3M4] = false; + expecteds[U8][F8E8M0FNU] = false; expecteds[U16][PRED] = false; expecteds[U16][S1] = false; expecteds[U16][S2] = false; @@ -418,6 +429,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][F8E5M2FNUZ] = false; 
expecteds[U16][F8E4M3FNUZ] = false; expecteds[U16][F8E3M4] = false; + expecteds[U16][F8E8M0FNU] = false; expecteds[U32][PRED] = false; expecteds[U32][S1] = false; expecteds[U32][S2] = false; @@ -447,6 +459,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][F8E5M2FNUZ] = false; expecteds[U32][F8E4M3FNUZ] = false; expecteds[U32][F8E3M4] = false; + expecteds[U32][F8E8M0FNU] = false; expecteds[U64][PRED] = false; expecteds[U64][S1] = false; expecteds[U64][S2] = false; @@ -476,6 +489,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][F8E5M2FNUZ] = false; expecteds[U64][F8E4M3FNUZ] = false; expecteds[U64][F8E3M4] = false; + expecteds[U64][F8E8M0FNU] = false; expecteds[F16][PRED] = false; expecteds[F16][S1] = false; expecteds[F16][S2] = false; @@ -505,6 +519,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][F8E5M2FNUZ] = false; expecteds[F16][F8E4M3FNUZ] = false; expecteds[F16][F8E3M4] = false; + expecteds[F16][F8E8M0FNU] = false; expecteds[F32][PRED] = false; expecteds[F32][S1] = false; expecteds[F32][S2] = false; @@ -534,6 +549,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][F8E5M2FNUZ] = false; expecteds[F32][F8E4M3FNUZ] = false; expecteds[F32][F8E3M4] = false; + expecteds[F32][F8E8M0FNU] = false; expecteds[F64][PRED] = false; expecteds[F64][S1] = false; expecteds[F64][S2] = false; @@ -563,6 +579,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][F8E5M2FNUZ] = false; expecteds[F64][F8E4M3FNUZ] = false; expecteds[F64][F8E3M4] = false; + expecteds[F64][F8E8M0FNU] = false; expecteds[C64][PRED] = false; expecteds[C64][S1] = false; expecteds[C64][S2] = false; @@ -592,6 +609,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][F8E5M2FNUZ] = false; expecteds[C64][F8E4M3FNUZ] = false; expecteds[C64][F8E3M4] = false; + expecteds[C64][F8E8M0FNU] = false; expecteds[BF16][PRED] = false; expecteds[BF16][S1] = false; expecteds[BF16][S2] = false; @@ -621,6 +639,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][F8E5M2FNUZ] = false; expecteds[BF16][F8E4M3FNUZ] = false; expecteds[BF16][F8E3M4] = false; + expecteds[BF16][F8E8M0FNU] = false; expecteds[C128][PRED] = false; expecteds[C128][S1] = false; expecteds[C128][S2] = false; @@ -650,6 +669,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][F8E5M2FNUZ] = false; expecteds[C128][F8E4M3FNUZ] = false; expecteds[C128][F8E3M4] = false; + expecteds[C128][F8E8M0FNU] = false; expecteds[F4E2M1FN][PRED] = false; expecteds[F4E2M1FN][S1] = false; expecteds[F4E2M1FN][S2] = false; @@ -679,6 +699,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F4E2M1FN][F8E4M3FNUZ] = false; expecteds[F4E2M1FN][F8E5M2FNUZ] = false; expecteds[F4E2M1FN][F8E3M4] = true; + expecteds[F4E2M1FN][F8E8M0FNU] = false; expecteds[F8E5M2][PRED] = false; expecteds[F8E5M2][S1] = false; expecteds[F8E5M2][S2] = false; @@ -708,6 +729,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][F8E5M2FNUZ] = false; expecteds[F8E5M2][F8E4M3FNUZ] = false; expecteds[F8E5M2][F8E3M4] = false; + expecteds[F8E5M2][F8E8M0FNU] = false; expecteds[F8E4M3][PRED] = false; expecteds[F8E4M3][S1] = false; expecteds[F8E4M3][S2] = false; @@ -737,6 +759,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3][F8E4M3FNUZ] = false; expecteds[F8E4M3][F8E4M3B11FNUZ] = false; expecteds[F8E4M3][F8E3M4] = false; + expecteds[F8E4M3][F8E8M0FNU] = false; expecteds[F8E4M3FN][PRED] = false; expecteds[F8E4M3FN][S1] = false; expecteds[F8E4M3FN][S2] = false; @@ -766,6 +789,7 @@ 
TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][F8E4M3FNUZ] = false; expecteds[F8E4M3FN][F8E4M3B11FNUZ] = false; expecteds[F8E4M3FN][F8E3M4] = false; + expecteds[F8E4M3FN][F8E8M0FNU] = false; expecteds[F8E4M3B11FNUZ][PRED] = false; expecteds[F8E4M3B11FNUZ][S1] = false; expecteds[F8E4M3B11FNUZ][S2] = false; @@ -795,6 +819,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][F8E4M3FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E3M4] = false; + expecteds[F8E4M3B11FNUZ][F8E8M0FNU] = false; expecteds[F8E5M2FNUZ][PRED] = false; expecteds[F8E5M2FNUZ][S1] = false; expecteds[F8E5M2FNUZ][S2] = false; @@ -824,6 +849,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][F8E5M2FNUZ] = true; expecteds[F8E5M2FNUZ][F8E4M3FNUZ] = false; expecteds[F8E5M2FNUZ][F8E3M4] = false; + expecteds[F8E5M2FNUZ][F8E8M0FNU] = false; expecteds[F8E4M3FNUZ][PRED] = false; expecteds[F8E4M3FNUZ][S1] = false; expecteds[F8E4M3FNUZ][S2] = false; @@ -853,6 +879,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3FNUZ][F8E4M3FNUZ] = true; expecteds[F8E4M3FNUZ][F8E3M4] = false; + expecteds[F8E4M3FNUZ][F8E8M0FNU] = false; expecteds[F8E3M4][PRED] = false; expecteds[F8E3M4][S1] = false; expecteds[F8E3M4][S2] = false; @@ -882,6 +909,37 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E3M4][F8E4M3FNUZ] = false; expecteds[F8E3M4][F8E4M3B11FNUZ] = false; expecteds[F8E3M4][F8E3M4] = true; + expecteds[F8E3M4][F8E8M0FNU] = false; + expecteds[F8E8M0FNU][PRED] = false; + expecteds[F8E8M0FNU][S1] = false; + expecteds[F8E8M0FNU][S2] = false; + expecteds[F8E8M0FNU][S4] = false; + expecteds[F8E8M0FNU][S8] = false; + expecteds[F8E8M0FNU][S16] = false; + expecteds[F8E8M0FNU][S32] = false; + expecteds[F8E8M0FNU][S64] = false; + expecteds[F8E8M0FNU][U1] = false; + expecteds[F8E8M0FNU][U2] = false; + expecteds[F8E8M0FNU][U4] = false; + expecteds[F8E8M0FNU][U8] = false; + expecteds[F8E8M0FNU][U16] = false; + expecteds[F8E8M0FNU][U32] = false; + expecteds[F8E8M0FNU][U64] = false; + expecteds[F8E8M0FNU][F16] = false; + expecteds[F8E8M0FNU][F32] = true; + expecteds[F8E8M0FNU][F64] = true; + expecteds[F8E8M0FNU][C64] = true; + expecteds[F8E8M0FNU][BF16] = true; + expecteds[F8E8M0FNU][C128] = true; + expecteds[F8E8M0FNU][F4E2M1FN] = false; + expecteds[F8E8M0FNU][F8E5M2] = false; + expecteds[F8E8M0FNU][F8E4M3] = false; + expecteds[F8E8M0FNU][F8E4M3FN] = false; + expecteds[F8E8M0FNU][F8E4M3B11FNUZ] = false; + expecteds[F8E8M0FNU][F8E4M3FNUZ] = false; + expecteds[F8E8M0FNU][F8E5M2FNUZ] = false; + expecteds[F8E8M0FNU][F8E3M4] = false; + expecteds[F8E8M0FNU][F8E8M0FNU] = true; for (int from_type_int = PrimitiveType_MIN; from_type_int < PrimitiveType_ARRAYSIZE; ++from_type_int) { diff --git a/xla/python/ifrt/dtype.cc b/xla/python/ifrt/dtype.cc index 3f4b35247c32c..4ba15201bcf86 100644 --- a/xla/python/ifrt/dtype.cc +++ b/xla/python/ifrt/dtype.cc @@ -40,6 +40,7 @@ std::optional DType::byte_size() const { case kU8: case kF8E3M4: case kF8E4M3: + case kF8E8M0FNU: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -85,6 +86,7 @@ std::optional DType::bit_size() const { case kU8: case kF8E3M4: case kF8E4M3: + case kF8E8M0FNU: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -147,6 +149,7 @@ absl::StatusOr DType::FromProto(const DTypeProto& dtype_proto) { // CASE(F4E2M1FN); // CASE(F8E3M4); // 
CASE(F8E4M3); + // CASE(F8E8M0FNU); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); @@ -196,6 +199,7 @@ DTypeProto DType::ToProto() const { // CASE(F4E2M1FN); // CASE(F8E3M4); // CASE(F8E4M3); + // CASE(F8E8M0FNU); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); diff --git a/xla/python/ifrt/dtype.h b/xla/python/ifrt/dtype.h index f00ceae937d01..ff724df90a230 100644 --- a/xla/python/ifrt/dtype.h +++ b/xla/python/ifrt/dtype.h @@ -88,11 +88,12 @@ class DType { kF8E4M3FNUZ = 25, kF8E5M2 = 19, kF8E5M2FNUZ = 24, + kF8E8M0FNU = 31, // MX floating point types. kF4E2M1FN = 30, - // Next = 31 + // Next = 32 // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind diff --git a/xla/python/ifrt/dtype.proto b/xla/python/ifrt/dtype.proto index 9b3d7c8d1923a..8440ce792b93a 100644 --- a/xla/python/ifrt/dtype.proto +++ b/xla/python/ifrt/dtype.proto @@ -70,6 +70,7 @@ message DTypeProto { KIND_F8E4M3FNUZ = 25; KIND_F8E5M2 = 19; KIND_F8E5M2FNUZ = 24; + KIND_F8E8M0FNU = 31; // MX floating point types. KIND_F4E2M1FN = 30; diff --git a/xla/python/ifrt/dtype_test.cc b/xla/python/ifrt/dtype_test.cc index 384117353cf27..55051f0b29cfc 100644 --- a/xla/python/ifrt/dtype_test.cc +++ b/xla/python/ifrt/dtype_test.cc @@ -49,14 +49,15 @@ TEST(DTypeTest, ByteSize) { {DType::kF8E3M4, 1}, {DType::kF8E4M3, 1}, {DType::kF8E4M3FN, 1}, {DType::kF8E4M3B11FNUZ, 1}, {DType::kF8E4M3FNUZ, 1}, {DType::kF8E5M2, 1}, - {DType::kF8E5M2FNUZ, 1}, {DType::kS16, 2}, - {DType::kU16, 2}, {DType::kF16, 2}, - {DType::kBF16, 2}, {DType::kS32, 4}, - {DType::kU32, 4}, {DType::kF32, 4}, - {DType::kS64, 8}, {DType::kU64, 8}, - {DType::kF64, 8}, {DType::kC64, 8}, - {DType::kC128, 16}, {DType::kToken, -1}, - {DType::kInvalid, -1}, {DType::kString, -1}, + {DType::kF8E5M2FNUZ, 1}, {DType::kF8E8M0FNU, 1}, + {DType::kS16, 2}, {DType::kU16, 2}, + {DType::kF16, 2}, {DType::kBF16, 2}, + {DType::kS32, 4}, {DType::kU32, 4}, + {DType::kF32, 4}, {DType::kS64, 8}, + {DType::kU64, 8}, {DType::kF64, 8}, + {DType::kC64, 8}, {DType::kC128, 16}, + {DType::kToken, -1}, {DType::kInvalid, -1}, + {DType::kString, -1}, })) { EXPECT_EQ(DType(kind).byte_size(), byte_size == -1 ? std::nullopt : std::make_optional(byte_size)); @@ -73,14 +74,15 @@ TEST(DTypeTest, BitSize) { {DType::kF8E3M4, 8}, {DType::kF8E4M3, 8}, {DType::kF8E4M3FN, 8}, {DType::kF8E4M3B11FNUZ, 8}, {DType::kF8E4M3FNUZ, 8}, {DType::kF8E5M2, 8}, - {DType::kF8E5M2FNUZ, 8}, {DType::kS16, 16}, - {DType::kU16, 16}, {DType::kF16, 16}, - {DType::kBF16, 16}, {DType::kS32, 32}, - {DType::kU32, 32}, {DType::kF32, 32}, - {DType::kS64, 64}, {DType::kU64, 64}, - {DType::kF64, 64}, {DType::kC64, 64}, - {DType::kC128, 128}, {DType::kToken, -1}, - {DType::kInvalid, -1}, {DType::kString, -1}, + {DType::kF8E5M2FNUZ, 8}, {DType::kF8E8M0FNU, 4}, + {DType::kS16, 16}, {DType::kU16, 16}, + {DType::kF16, 16}, {DType::kBF16, 16}, + {DType::kS32, 32}, {DType::kU32, 32}, + {DType::kF32, 32}, {DType::kS64, 64}, + {DType::kU64, 64}, {DType::kF64, 64}, + {DType::kC64, 64}, {DType::kC128, 128}, + {DType::kToken, -1}, {DType::kInvalid, -1}, + {DType::kString, -1}, })) { EXPECT_EQ(DType(kind).bit_size(), bit_size == -1 ? 
std::nullopt : std::make_optional(bit_size)); diff --git a/xla/python/pjrt_ifrt/pjrt_dtype.cc b/xla/python/pjrt_ifrt/pjrt_dtype.cc index fc2f55d9d0976..2af3281a588cc 100644 --- a/xla/python/pjrt_ifrt/pjrt_dtype.cc +++ b/xla/python/pjrt_ifrt/pjrt_dtype.cc @@ -52,6 +52,7 @@ absl::StatusOr ToPrimitiveType(DType dtype) { CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ); CASE(DType::kF8E5M2, xla::PrimitiveType::F8E5M2); CASE(DType::kF8E5M2FNUZ, xla::PrimitiveType::F8E5M2FNUZ); + CASE(DType::kF8E8M0FNU, xla::PrimitiveType::F8E8M0FNU); CASE(DType::kF16, xla::PrimitiveType::F16); CASE(DType::kF32, xla::PrimitiveType::F32); CASE(DType::kBF16, xla::PrimitiveType::BF16); @@ -92,6 +93,7 @@ absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::F8E4M3FNUZ: case xla::PrimitiveType::F8E5M2: case xla::PrimitiveType::F8E5M2FNUZ: + case xla::PrimitiveType::F8E8M0FNU: case xla::PrimitiveType::F16: case xla::PrimitiveType::F32: case xla::PrimitiveType::BF16: diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index 334bf18ef0623..19dfdfc803f6c 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -208,6 +208,9 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E5M2FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E8M0FNU; } else if (std::is_same() || !options.squash_64bit_types) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<0>()); type = primitive_util::NativeToPrimitiveType(); @@ -422,6 +425,10 @@ absl::StatusOr DevicePut(nb::handle arg, HandleNumpyScalar; (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = HandleNumpyScalar; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = + HandleNumpyScalar; + } (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; @@ -605,6 +612,7 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, // (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler; // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; + // (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; diff --git a/xla/python/types.cc b/xla/python/types.cc index 215d802ec069d..c1653279c7858 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -66,6 +66,7 @@ struct CustomDtypes { nb_dtype float8_e4m3fnuz; nb_dtype float8_e5m2; nb_dtype float8_e5m2fnuz; + std::optional float8_e8m0fnu; std::optional int2; nb_dtype int4; std::optional uint2; @@ -96,6 +97,10 @@ const CustomDtypes& GetCustomDtypes() { nb_dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")); dtypes->float8_e5m2fnuz = nb_dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")); + if (nb::hasattr(ml_dtypes, "float8_e8m0fnu")) { + dtypes->float8_e8m0fnu = + nb_dtype::from_args(ml_dtypes.attr("float8_e8m0fnu")); + } dtypes->int4 = nb_dtype::from_args(ml_dtypes.attr("int4")); dtypes->uint4 = nb_dtype::from_args(ml_dtypes.attr("uint4")); if (nb::hasattr(ml_dtypes, "int2")) { @@ -166,6 +171,9 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { map->emplace(custom_dtypes.float8_e4m3fnuz, F8E4M3FNUZ); map->emplace(custom_dtypes.float8_e5m2, F8E5M2); 
map->emplace(custom_dtypes.float8_e5m2fnuz, F8E5M2FNUZ); + if (custom_dtypes.float8_e8m0fnu.has_value()) { + map->emplace(*custom_dtypes.float8_e8m0fnu, F8E8M0FNU); + } if (custom_dtypes.int2.has_value()) { map->emplace(*custom_dtypes.int2, S2); } @@ -250,6 +258,11 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return custom_dtypes.float8_e5m2; case F8E5M2FNUZ: return custom_dtypes.float8_e5m2fnuz; + case F8E8M0FNU: + if (custom_dtypes.float8_e8m0fnu.has_value()) { + return *custom_dtypes.float8_e8m0fnu; + } + break; case BF16: return custom_dtypes.bfloat16; case F16: @@ -345,6 +358,11 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return custom_dtypes.float8_e5m2; case ifrt::DType::kF8E5M2FNUZ: return custom_dtypes.float8_e5m2fnuz; + case ifrt::DType::kF8E8M0FNU: + if (custom_dtypes.float8_e8m0fnu.has_value()) { + return *custom_dtypes.float8_e8m0fnu; + } + break; case ifrt::DType::kString: // PEP 3118 code for "pointer to Python Object". We use Python objects // instead of 'U' (Unicode string) or 'V' (raw data) because the latter @@ -413,6 +431,9 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_float8_e5m2 = nb::object(ml_dtypes.attr("float8_e5m2")); dtypes->np_float8_e4m3fnuz = nb::object(ml_dtypes.attr("float8_e4m3fnuz")); dtypes->np_float8_e5m2fnuz = nb::object(ml_dtypes.attr("float8_e5m2fnuz")); + if (nb::hasattr(ml_dtypes, "float8_e8m0fnu")) { + dtypes->np_float8_e8m0fnu = nb::object(ml_dtypes.attr("float8_e8m0fnu")); + } dtypes->np_float16 = nb::object(numpy.attr("float16")); dtypes->np_float32 = nb::object(numpy.attr("float32")); dtypes->np_float64 = nb::object(numpy.attr("float64")); diff --git a/xla/python/types.h b/xla/python/types.h index 1476e179b5988..babdf5a9bd416 100644 --- a/xla/python/types.h +++ b/xla/python/types.h @@ -89,6 +89,7 @@ struct NumpyScalarTypes { nanobind::object np_float8_e4m3fnuz; nanobind::object np_float8_e5m2; nanobind::object np_float8_e5m2fnuz; + std::optional np_float8_e8m0fnu; nanobind::object np_float16; nanobind::object np_float32; nanobind::object np_float64; diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 84902b4dc1d01..0be7475661a57 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -208,6 +208,7 @@ NB_MODULE(xla_extension, m) { // .value("F4E2M1FN", F4E2M1FN) // .value("F8E3M4", F8E3M4) // .value("F8E4M3", F8E4M3) + // .value("F8E8M0FNU", F8E8M0FNU) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) .value("F8E4M3FNUZ", F8E4M3FNUZ) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 84733f2c7a43e..f78c136b5f016 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -283,6 +283,7 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): # float4_e2m1fn = ml_dtypes.float4_e2m1fn # float8_e3m4 = ml_dtypes.float8_e3m4 # float8_e4m3 = ml_dtypes.float8_e4m3 +# float8_e8m0fnu = ml_dtypes.float8_e8m0fnu float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz @@ -305,6 +306,7 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): # PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), + # PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), PrimitiveType.F8E5M2: np.dtype(float8_e5m2), diff --git a/xla/python/xla_client.pyi 
b/xla/python/xla_client.pyi index 957b6ed946bc4..ff1d67059daf4 100644 --- a/xla/python/xla_client.pyi +++ b/xla/python/xla_client.pyi @@ -65,6 +65,7 @@ bfloat16: type[numpy.generic] # float4_e2m1fn: type[numpy.generic] # float8_e3m4: type[numpy.generic] # float8_e4m3: type[numpy.generic] +# float8_e8m0fnu: type[numpy.generic] float8_e4m3fn: type[numpy.generic] float8_e4m3b11fnuz: type[numpy.generic] float8_e4m3fnuz: type[numpy.generic] diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index 27cf64b3ba49d..37718e3fa8790 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -58,6 +58,7 @@ # float4_e2m1fn = xla_client.float4_e2m1fn # float8_e3m4 = xla_client.float8_e3m4 # float8_e4m3 = xla_client.float8_e4m3 +# float8_e8m0fnu = xla_client.float8_e8m0fnu float8_e4m3fn = xla_client.float8_e4m3fn float8_e4m3fnuz = xla_client.float8_e4m3fnuz float8_e4m3b11fnuz = xla_client.float8_e4m3b11fnuz @@ -190,7 +191,7 @@ def TestFactory(xla_backend, fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] standard_dtypes += fp8_dtypes # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - # standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3] + # standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3, float8_e8m0fnu] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index a1942ff3d187f..ee7df05462f7b 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -82,6 +82,7 @@ class PrimitiveType(enum.IntEnum): F8E4M3FNUZ: PrimitiveType F8E5M2: PrimitiveType F8E5M2FNUZ: PrimitiveType + F8E8M0FNU: PrimitiveType BF16: PrimitiveType F16: PrimitiveType F32: PrimitiveType diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index bc64f6dcb419e..e34ec05d9e789 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -604,6 +604,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(&u4_support); FloatSupport f4e2m1fn_support(F4E2M1FN, F16); pipeline.AddPass(&f4e2m1fn_support); + FloatSupport f8e8m0fnu_support(F8E8M0FNU, F32); + pipeline.AddPass(&f8e8m0fnu_support); // After canonicalization, there may be more batch dots that can be // simplified. 
pipeline.AddPass(); diff --git a/xla/service/cpu/onednn_memory_util.h b/xla/service/cpu/onednn_memory_util.h index 2a9fa25ca185b..90c4f6c82e408 100644 --- a/xla/service/cpu/onednn_memory_util.h +++ b/xla/service/cpu/onednn_memory_util.h @@ -73,7 +73,7 @@ inline dnnl::memory::data_type ToOneDnnDataType(PrimitiveType ptype) { // TODO(intel-tf): properly handle not supported types: // S16, S64, U16, U32, U64, C64, C128, F8E5M2, F8E4M3FN, S4, U4, - // F8E4M3B11FNUZ, F8E4M3, F8E3M4, F4E2M1FN + // F8E4M3B11FNUZ, F8E4M3, F8E3M4, F4E2M1FN, F8E8M0FNU default: return dt::undef; } diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index f4e5b4faee6a8..46cdd1598d6ab 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -885,6 +885,47 @@ llvm::Value* EmitF4e2m1fnToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { return b->CreateBitCast(f16_signed, b->getHalfTy()); } +llvm::Value* EmitF32ToF8e8m0fnu(llvm::Value* f32_value, llvm::IRBuilder<>* b) { + llvm::Value* as_int32 = b->CreateBitCast(f32_value, b->getInt32Ty()); + + // Result is NaN if input is zero, negative, infinity or NaN. + auto i32_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt32Ty(), val); + }; + llvm::Value* is_denorm = b->CreateICmpULE(as_int32, i32_const(0x400000)); + llvm::Value* is_nan = + b->CreateOr(b->CreateICmpSLE(as_int32, i32_const(0)), + b->CreateICmpSGE(as_int32, i32_const(0x7F400000))); + + // Round the value and extract exponent. + llvm::Value* rounded = b->CreateAdd(as_int32, i32_const(0x400000)); + llvm::Value* shifted = b->CreateAShr(rounded, 23); + llvm::Value* finite = b->CreateSelect(is_denorm, i32_const(0x00), shifted); + llvm::Value* f32_result = b->CreateSelect(is_nan, i32_const(0xFF), finite); + + // Truncate to the result type. + return b->CreateTrunc(f32_result, b->getInt8Ty()); +} + +llvm::Value* EmitF8e8m0fnuToF32(llvm::Value* f8_value, llvm::IRBuilder<>* b) { + // Shift exponent to the left for the normal case. + llvm::Value* as_int32 = b->CreateZExt(f8_value, b->getInt32Ty()); + llvm::Value* shifted = b->CreateShl(as_int32, 23); + + // Special values for 0x00 (denorm) and 0xFF (NaN). 
+ auto i32_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt32Ty(), val); + }; + llvm::Value* is_zero = b->CreateICmpEQ(as_int32, i32_const(0)); + llvm::Value* is_nan = b->CreateICmpEQ(as_int32, i32_const(0xFF)); + llvm::Value* denorm_or_shifted = + b->CreateSelect(is_zero, i32_const(0x00400000), shifted); + llvm::Value* f32_result = + b->CreateSelect(is_nan, i32_const(0x7FC00000), denorm_or_shifted); + + return b->CreateBitCast(f32_result, b->getFloatTy()); +} + llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, PrimitiveType from_type, PrimitiveType to_type, llvm::Module* module, @@ -985,6 +1026,12 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), b_); } + if (to_type == F8E8M0FNU) { + return EmitF32ToF8e8m0fnu( + EmitIntegralToFloating(operand_value, from_type, F32, module_, + b_), + b_); + } if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz( F16, @@ -1198,10 +1245,21 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F8E8M0FNU) { + TF_RET_CHECK(to_type != F8E8M0FNU); + operand_value = EmitF8e8m0fnuToF32(operand_value, b_); + from_type = F32; + if (from_type == to_type) { + return operand_value; + } + } if (from_type == F8E5M2FNUZ || from_type == F8E4M3FNUZ) { TF_RET_CHECK(to_type != from_type); PrimitiveType cast_type = primitive_util::IsFloatingPointType(to_type) ? to_type : F16; + if (to_type == F8E8M0FNU) { + cast_type = F32; + } TF_ASSIGN_OR_RETURN(operand_value, EmitF8fnuzToFloating(from_type, operand_value, cast_type, b_, module_)); @@ -1282,6 +1340,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitF16ToF4e2m1fn(operand_value, b_); } + if (to_type == F8E8M0FNU) { + // Cast to F32 first. Casts to F8E8M0FNU must be from F32. 
+ if (from_type != F32) { + operand_value = b_->CreateFPCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(F32, module_)); + } + return EmitF32ToF8e8m0fnu(operand_value, b_); + } if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz(from_type, operand_value, to_type, b_); } @@ -1835,6 +1901,9 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( } else if (operand_type == F4E2M1FN) { lhs_value = EmitF4e2m1fnToF16(lhs_value, b_); rhs_value = EmitF4e2m1fnToF16(rhs_value, b_); + } else if (operand_type == F8E8M0FNU) { + lhs_value = EmitF8e8m0fnuToF32(lhs_value, b_); + rhs_value = EmitF8e8m0fnuToF32(rhs_value, b_); } else if (operand_type == F8E5M2FNUZ || operand_type == F8E4M3FNUZ) { TF_ASSIGN_OR_RETURN( lhs_value, diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index bfd40fd5d6ee0..b3f4b8ddef894 100644 --- a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -102,7 +102,7 @@ using FloatTypes = ::testing::Types; + tsl::float8_e5m2fnuz, tsl::float8_e8m0fnu>; TYPED_TEST_SUITE(ElementalIrEmitterExecutionTypedTest, FloatTypes); @@ -615,7 +615,8 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { std::is_same() || std::is_same() || std::is_same() || - std::is_same()) { + std::is_same() || + std::is_same()) { GTEST_SKIP() << "Skipping test for type " << tname; } const auto hlo_text = absl::StrReplaceAll(R"( @@ -630,8 +631,9 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { TYPED_TEST(ElementalIrEmitterExecutionTypedTest, BatchDotFloat) { auto tname = this->TypeName(); - if (std::is_same()) { - GTEST_SKIP() << "Dot operation on E2M1 is not supported"; + if (std::is_same() || + std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; } const auto hlo_text = absl::StrReplaceAll(R"( HloModule matmul diff --git a/xla/service/float8_fnuz_ir_emitter.cc b/xla/service/float8_fnuz_ir_emitter.cc index b0c012a75f9f0..e0be95da5f668 100644 --- a/xla/service/float8_fnuz_ir_emitter.cc +++ b/xla/service/float8_fnuz_ir_emitter.cc @@ -56,6 +56,8 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( return &llvm::APFloat::Float8E5M2(); case F8E5M2FNUZ: return &llvm::APFloat::Float8E5M2FNUZ(); + case F8E8M0FNU: + return &llvm::APFloat::Float8E8M0FNU(); case BF16: return &llvm::APFloat::BFloat(); case F16: @@ -83,6 +85,7 @@ absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilderBase* b, case F8E4M3FNUZ: case F8E5M2: case F8E5M2FNUZ: + case F8E8M0FNU: return b->getInt8Ty(); case BF16: return b->getBFloatTy(); diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 6c45121ae1ff5..da62f8c05c9e8 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1492,6 +1492,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( const GpuFloatSupport s4_support(gpu_version, S4, S8); const GpuFloatSupport u4_support(gpu_version, U4, U8); const GpuFloatSupport f4e2m1fn_support(gpu_version, F4E2M1FN, F16); + const GpuFloatSupport f8e8m0fnu_support(gpu_version, F8E8M0FNU, F32); auto add_float_normalization = [&](HloPassPipeline& pipeline) { auto& sub_pipeline = pipeline.AddPass("float_normalization"); @@ -1506,6 +1507,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( sub_pipeline.AddPass(&s4_support); sub_pipeline.AddPass(&u4_support); sub_pipeline.AddPass(&f4e2m1fn_support); + sub_pipeline.AddPass(&f8e8m0fnu_support); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. 
if (debug_options.xla_allow_excess_precision()) { sub_pipeline.AddPass(); diff --git a/xla/service/gpu/tests/float_conversions_test.cc b/xla/service/gpu/tests/float_conversions_test.cc index 62179eefcf3c5..6e0e14e320a7f 100644 --- a/xla/service/gpu/tests/float_conversions_test.cc +++ b/xla/service/gpu/tests/float_conversions_test.cc @@ -31,7 +31,8 @@ INSTANTIATE_TEST_SUITE_P(FloatConversionParamSuite, FloatConversionParamTest, ::testing::Values("f64", "f32", "f16", "bf16", "f4e2m1fn", "f8e5m2", "f8e5m2fnuz", "f8e4m3", "f8e4m3fn", "f8e4m3fnuz", - "f8e4m3b11fnuz", "f8e3m4")); + "f8e4m3b11fnuz", "f8e3m4", + "f8e8m0fnu")); TEST_P(FloatConversionParamTest, FloatToF16) { auto type_name = GetParam(); diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc index c4be8818f5ff6..b937dbc1500b6 100644 --- a/xla/service/llvm_ir/llvm_util.cc +++ b/xla/service/llvm_ir/llvm_util.cc @@ -208,6 +208,7 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, case F8E4M3B11FNUZ: case F8E4M3FNUZ: case F8E3M4: + case F8E8M0FNU: // We represent F8 as an int since there is no LLVM F8 dtype. return llvm::Type::getInt8Ty(context); case BF16: diff --git a/xla/stream_executor/data_type.h b/xla/stream_executor/data_type.h index dc77c0f711a45..e3e7d1f17e312 100644 --- a/xla/stream_executor/data_type.h +++ b/xla/stream_executor/data_type.h @@ -65,6 +65,10 @@ struct ToDataType { static constexpr DataType value = DataType::kF8E5M2FNUZ; }; template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E8M0FNU; +}; +template <> struct ToDataType { static constexpr DataType value = DataType::kFloat; }; diff --git a/xla/stream_executor/dnn.cc b/xla/stream_executor/dnn.cc index 5d41b3152790b..24851e56d75ed 100644 --- a/xla/stream_executor/dnn.cc +++ b/xla/stream_executor/dnn.cc @@ -76,6 +76,7 @@ constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; diff --git a/xla/stream_executor/gpu/gpu_blas_lt.cc b/xla/stream_executor/gpu/gpu_blas_lt.cc index 41db3c6c90104..182af599af9e5 100644 --- a/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -58,6 +58,8 @@ absl::StatusOr AsBlasDataType(PrimitiveType dtype) { return DataType::kF8E3M4; case PrimitiveType::F4E2M1FN: return DataType::kF4E2M1FN; + case PrimitiveType::F8E8M0FNU: + return DataType::kF8E8M0FNU; case PrimitiveType::S8: return DataType::kInt8; case PrimitiveType::F16: @@ -97,6 +99,8 @@ absl::StatusOr AsXlaPrimitiveType(DataType dtype) { return PrimitiveType::F8E3M4; case DataType::kF4E2M1FN: return PrimitiveType::F4E2M1FN; + case DataType::kF8E8M0FNU: + return PrimitiveType::F8E8M0FNU; case DataType::kInt8: return PrimitiveType::S8; case DataType::kHalf: @@ -159,6 +163,7 @@ absl::StatusOr GetBlasComputationType( case PrimitiveType::F8E4M3FNUZ: // fall-through case PrimitiveType::F8E3M4: // fall-through case PrimitiveType::F4E2M1FN: // fall-through + case PrimitiveType::F8E8M0FNU: // fall-through case PrimitiveType::F16: // fall-through case PrimitiveType::BF16: // Accumulate in f32 precision. 
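[Reviewer note, not part of the patch] The elemental IR helpers EmitF32ToF8e8m0fnu / EmitF8e8m0fnuToF32 above are easiest to sanity-check outside the emitter. The standalone C++ sketch below mirrors the same bit manipulation on plain integers; the function names and the main() checks are illustrative only, and the expected values follow the ConvertF32F8e8m0fnuRoundtrip table later in this change.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Mirrors the IR emitted by EmitF32ToF8e8m0fnu: zero, negative values, NaN,
// infinity and anything >= 0x1.8p127 become NaN (0xFF); denormals at or below
// 2^-127 clamp to the minimum exponent (0x00); otherwise round to the nearest
// power of two by adding half of the f32 mantissa range before shifting.
uint8_t F32ToE8M0(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if (static_cast<int32_t>(bits) <= 0 || bits >= 0x7F400000u) return 0xFF;
  if (bits <= 0x00400000u) return 0x00;
  return static_cast<uint8_t>((bits + 0x00400000u) >> 23);
}

// Mirrors EmitF8e8m0fnuToF32: 0xFF maps to a quiet NaN, 0x00 maps to the f32
// denormal 2^-127 (0x00400000); everything else is the exponent shifted left.
float E8M0ToF32(uint8_t e) {
  uint32_t bits = e == 0xFF   ? 0x7FC00000u
                  : e == 0x00 ? 0x00400000u
                              : static_cast<uint32_t>(e) << 23;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  std::printf("%a -> 0x%02X -> %a\n", 3.0f, F32ToE8M0(3.0f),
              E8M0ToF32(F32ToE8M0(3.0f)));          // 0x1.8p1 rounds up to 0x1p2
  std::printf("%a -> 0x%02X -> %a\n", 0x1p127f, F32ToE8M0(0x1p127f),
              E8M0ToF32(F32ToE8M0(0x1p127f)));      // max value survives the roundtrip
  std::printf("0.0 -> 0x%02X\n", F32ToE8M0(0.0f));  // no zero in E8M0: encodes NaN (0xFF)
  return 0;
}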
diff --git a/xla/stream_executor/rocm/hip_blas_utils.cc b/xla/stream_executor/rocm/hip_blas_utils.cc index dd1cbcf2ea29c..8864476bf0d82 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/xla/stream_executor/rocm/hip_blas_utils.cc @@ -40,8 +40,9 @@ hipDataType AsHipblasDataType(blas::DataType type) { case blas::DataType::kF8E4M3FN: case blas::DataType::kF8E3M4: case blas::DataType::kF4E2M1FN: + case blas::DataType::kF8E8M0FNU: LOG(FATAL) << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN, " - "F8E3M4 and F4E2M1FN"; + "F8E3M4, F4E2M1FN and F8E8M0FNU"; #if TF_ROCM_VERSION >= 60000 case blas::DataType::kF8E5M2FNUZ: return HIP_R_8F_E5M2_FNUZ; diff --git a/xla/tests/BUILD b/xla/tests/BUILD index edb19ff456376..2738d10a6c4cb 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -883,6 +883,7 @@ xla_test( "//xla:shape_util", "//xla:test", "//xla:types", + "//xla:util", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", diff --git a/xla/tests/array_elementwise_ops_test.cc b/xla/tests/array_elementwise_ops_test.cc index 1054c72fea658..b4fc4932a6a11 100644 --- a/xla/tests/array_elementwise_ops_test.cc +++ b/xla/tests/array_elementwise_ops_test.cc @@ -47,6 +47,7 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/types.h" +#include "xla/util.h" #include "tsl/platform/ml_dtypes.h" #if TENSORFLOW_USE_ROCM @@ -93,6 +94,20 @@ std::pair, std::vector> AllSignedPairs( return {xs, ys}; } +template +void AddNegativeValuesMaybeRemoveZero(std::vector& values) { + values.reserve(values.size() * 2); + if (!has_zero_v) { + values.erase(values.begin()); + } + for (size_t i = 0, n = values.size(); i < n; ++i) { + auto neg = -values[i]; + if (SignAndMagnitude(neg).first) { + values.push_back(neg); + } + } +} + class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: static constexpr float kEpsF32 = std::numeric_limits::epsilon(); @@ -1371,14 +1386,7 @@ class TotalOrderTest : public ClientLibraryTestBase { values.push_back(Eigen::numext::abs(std::numeric_limits::quiet_NaN())); } #endif - values.reserve(values.size() * 2); - for (size_t i = 0, n = values.size(); i < n; ++i) { - auto value = values[i]; - auto neg = -value; - if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { - values.push_back(neg); - } - } + AddNegativeValuesMaybeRemoveZero(values); std::vector lhs_data; std::vector rhs_data; lhs_data.reserve(values.size() * values.size()); @@ -1423,20 +1431,21 @@ class TotalOrderTest : public ClientLibraryTestBase { } }; -using Types = ::testing::Types; + float>; TYPED_TEST_SUITE(TotalOrderTest, Types); @@ -1463,13 +1472,7 @@ TYPED_TEST(TotalOrderTest, LargeMagnitudeVsNaN) { if constexpr (std::numeric_limits::has_infinity) { values.push_back(std::numeric_limits::infinity()); } - for (size_t i = 0, n = values.size(); i < n; ++i) { - auto value = values[i]; - auto neg = -value; - if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { - values.push_back(neg); - } - } + AddNegativeValuesMaybeRemoveZero(values); auto lhs = ConstantR1(&builder, values); auto rhs = ConstantR1( &builder, diff --git a/xla/tests/constants_test.cc b/xla/tests/constants_test.cc index f6cebe4e6bf74..2018f88079f86 100644 --- a/xla/tests/constants_test.cc +++ b/xla/tests/constants_test.cc @@ -48,11 +48,10 @@ class ConstantsTest : public ClientLibraryTestBase { template class ConstantsFloatTest : public ConstantsTest {}; -using FloatTypes = - ::testing::Types; +using FloatTypes = 
::testing::Types< + float, half, tsl::float4_e2m1fn, tsl::float8_e3m4, tsl::float8_e4m3, + tsl::float8_e4m3fn, tsl::float8_e4m3b11fnuz, tsl::float8_e4m3fnuz, + tsl::float8_e5m2, tsl::float8_e5m2fnuz, tsl::float8_e8m0fnu>; TYPED_TEST_SUITE(ConstantsFloatTest, FloatTypes); diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index 75d9653c3e700..8d7e32b05e3c5 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -53,10 +53,12 @@ class ConvertTestT : public ConvertTest { public: using ConvertTest::ConvertTest; }; -using FloatingPointTypeList = ::testing::Types< - tsl::float4_e2m1fn, tsl::float8_e3m4, tsl::float8_e4m3, tsl::float8_e4m3fn, - tsl::float8_e4m3fnuz, tsl::float8_e4m3b11fnuz, tsl::float8_e5m2, - tsl::float8_e5m2fnuz, Eigen::half, bfloat16, float, double>; +using FloatingPointTypeList = + ::testing::Types; TYPED_TEST_SUITE(ConvertTestT, FloatingPointTypeList); template @@ -744,7 +746,8 @@ XLA_TYPED_TEST(ConvertTestT, ConvertFPToPred) { auto a = ConstantR1(&builder, {FP{0.0}, FP{0.5}, FP{2.0}, FP{-0.0}}); ConvertElementType(a, PRED); - std::array expected = {false, true, true, false}; + bool zero_pred = !has_zero_v; + std::array expected = {zero_pred, true, true, zero_pred}; this->template ComputeAndCompareR1(&builder, expected, {}); } @@ -2085,5 +2088,112 @@ XLA_TYPED_TEST(ConvertTestF16, ConvertF4e2m1fnF16RoundtripExhaustive4) { this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F8E8M0FNU + +XLA_TEST_F(ConvertTest, ConvertF32F8e8m0fnuRoundtrip) { + // Convert from FP32 to FP8, then back to FP32. + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, nan}, // No zero values + {-0.0, nan}, + {1.0, 1.0}, + {-1.0, nan}, // No negative values + {nan, nan}, + {inf, nan}, + // clang-format on + {0x1.8p1, 0x1p2}, // Round-to-even up + {0x1.8p2, 0x1p3}, // Round-to-even up (not a mistake) + {0x1p127, 0x1p127}, // Max value + {0x1.7FFFFEp127, 0x1p127}, // Largest number that doesn't overflow + {0x1.8p127, nan}, // Smallest number that overflows + {0x1.FFFFFEp127, nan}, // Overflow + {0x1p-126, 0x1p-126}, // Smallest F8 normal + {0x0.800002p-126, 0x1p-126}, // Smallest number rounding up to normal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E8M0FNU); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e8m0fnuRoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e8m0fnu; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E8M0FNU); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +// This test is disabled on CPU, as converting 0x1p-127 from double to float +// using CVTSD2SS on x64 results in an underflow (even though the result is +// representable as denormalized float32). 
+XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_CPU(ConvertF8e8m0fnuRoundtripExhaustive2)) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back( + static_cast(Eigen::numext::bit_cast( + static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), F8E8M0FNU); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e8m0fnuRoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e8m0fnu; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e8m0fnuF16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f16; + for (int i = 0; i < 65536; i++) { + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f16), F8E8M0FNU); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + } // namespace } // namespace xla diff --git a/xla/tools/driver.cc b/xla/tools/driver.cc index 2db53e65cf75c..d1d6882b6532a 100644 --- a/xla/tools/driver.cc +++ b/xla/tools/driver.cc @@ -129,16 +129,19 @@ enum PrimitiveType { F8E5M2FNUZ, F8E4M3FNUZ, F8E3M4, + F8E8M0FNU, }; const std::vector& primitive_strings() { static auto vec = new std::vector( - {"s1", "s2", "s4", "s8", "s16", - "s32", "s64", "u1", "u2", "u4", - "u8", "u16", "u32", "u64", "f16", - "bf16", "f32", "f64", "c64", "c128", - "f4e2m1fn", "f8e5m2", "f8e4m3", "f8e4m3fn", "f8e4m3b11fnuz", - "f8e5m2fnuz", "f8e4m3fnuz", "f8e3m4"}); + {"s1", "s2", "s4", "s8", + "s16", "s32", "s64", "u1", + "u2", "u4", "u8", "u16", + "u32", "u64", "f16", "bf16", + "f32", "f64", "c64", "c128", + "f4e2m1fn", "f8e3m4", "f8e4m3", "f8e4m3b11fnuz", + "f8e4m3fn", "f8e4m3fnuz", "f8e5m2", "f8e5m2fnuz", + "f8e8m0fnu"}); return *vec; } @@ -423,6 +426,7 @@ void Fill(void* buffer, const ArrayShape& shape) { case F8E5M2FNUZ: case F8E4M3FNUZ: case F8E3M4: + case F8E8M0FNU: case F16: case BF16: case C64: @@ -484,6 +488,7 @@ void Display(const void* buffer, const ArrayShape& shape) { case F8E5M2FNUZ: case F8E4M3FNUZ: case F8E3M4: + case F8E8M0FNU: case F16: case BF16: case C64: diff --git a/xla/tsl/framework/type_traits.h b/xla/tsl/framework/type_traits.h index 0a0deb5f5f66b..bda9244327b4f 100644 --- a/xla/tsl/framework/type_traits.h +++ b/xla/tsl/framework/type_traits.h @@ -77,7 +77,8 @@ struct is_simple_type { std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value; }; diff --git a/xla/tsl/protobuf/dnn.proto b/xla/tsl/protobuf/dnn.proto index 09237a09e45b8..4a6d8fff6f72c 100644 --- a/xla/tsl/protobuf/dnn.proto +++ b/xla/tsl/protobuf/dnn.proto @@ -25,6 +25,7 @@ enum DataType { kF8E4M3 = 13; kF8E3M4 = 14; kF4E2M1FN = 15; + kF8E8M0FNU = 16; } // Describes how a convolution input or output layer's data is formatted. 
diff --git a/xla/tsl/python/lib/core/ml_dtypes.cc b/xla/tsl/python/lib/core/ml_dtypes.cc index 19dd8ddd2413a..a986efb7cca96 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.cc +++ b/xla/tsl/python/lib/core/ml_dtypes.cc @@ -77,6 +77,8 @@ struct MlDtypesInitInfo { py::dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")).num(); numpy_dtypes.float8_e5m2fnuz = py::dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")).num(); + numpy_dtypes.float8_e8m0fnu = + py::dtype::from_args(ml_dtypes.attr("float8_e8m0fnu")).num(); numpy_dtypes.int4 = py::dtype::from_args(ml_dtypes.attr("int4")).num(); numpy_dtypes.uint4 = py::dtype::from_args(ml_dtypes.attr("uint4")).num(); } catch (const std::exception& e) { @@ -95,6 +97,7 @@ struct MlDtypesInitInfo { numpy_dtypes.float8_e4m3b11fnuz == NPY_NOTYPE || numpy_dtypes.float8_e5m2 == NPY_NOTYPE || numpy_dtypes.float8_e5m2fnuz == NPY_NOTYPE || + numpy_dtypes.float8_e8m0fnu == NPY_NOTYPE || numpy_dtypes.int4 == NPY_NOTYPE || numpy_dtypes.uint4 == NPY_NOTYPE) { init_valid = false; } diff --git a/xla/tsl/python/lib/core/ml_dtypes.h b/xla/tsl/python/lib/core/ml_dtypes.h index c6523eaf38026..725d844c27bb4 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.h +++ b/xla/tsl/python/lib/core/ml_dtypes.h @@ -32,6 +32,7 @@ struct NumpyDtypes { int float8_e4m3fnuz; int float8_e5m2; int float8_e5m2fnuz; + int float8_e8m0fnu; int int4; int uint4; }; diff --git a/xla/types.h b/xla/types.h index f851ea5a66d59..15dd4f5d2d458 100644 --- a/xla/types.h +++ b/xla/types.h @@ -131,6 +131,8 @@ struct make_specialized_signed>> { template using make_specialized_signed_t = typename make_specialized_signed::type; +// has_negative_zero[_v] + template struct has_negative_zero : std::bool_constant::is_iec559> {}; @@ -144,6 +146,17 @@ struct has_negative_zero : std::bool_constant {}; template inline constexpr bool has_negative_zero_v = has_negative_zero::value; +// has_zero[_v] + +template +struct has_zero : std::bool_constant {}; + +template <> +struct has_zero : std::bool_constant {}; + +template +inline constexpr bool has_zero_v = has_zero::value; + } // namespace xla #endif // XLA_TYPES_H_ diff --git a/xla/util.cc b/xla/util.cc index 2d44ffeb1642f..023e09342f113 100644 --- a/xla/util.cc +++ b/xla/util.cc @@ -217,6 +217,11 @@ std::string RoundTripFpToString(tsl::float8_e3m4 value) { return result; } +std::string RoundTripFpToString(tsl::float8_e8m0fnu value) { + std::string result = GenericRoundTripFpToString(value); + return result; +} + std::string RoundTripFpToString(bfloat16 value) { std::string result = GenericRoundTripFpToString(value); RoundTripNanPayload(value, &result); diff --git a/xla/util.h b/xla/util.h index 0270dc398bcd7..6e3d60a7d0dea 100644 --- a/xla/util.h +++ b/xla/util.h @@ -440,6 +440,9 @@ std::string RoundTripFpToString(tsl::float8_e4m3fnuz value); // Returns a string which can losslessly round trip to a float8 E3M4. std::string RoundTripFpToString(tsl::float8_e3m4 value); +// Returns a string which can losslessly round trip to a float8 E8M0FNU. +std::string RoundTripFpToString(tsl::float8_e8m0fnu value); + // Returns a string which can losslessly round trip to a bfloat. 
std::string RoundTripFpToString(tsl::bfloat16 value); @@ -666,6 +669,12 @@ auto SignAndMagnitude(T x) { return std::make_pair(x_sign, x_abs_bits); } +template <> +inline auto SignAndMagnitude(tsl::float8_e8m0fnu x) { + uint8_t x_bits = Eigen::numext::bit_cast(x); + return std::make_pair(static_cast(0), x_bits); +} + template auto SignAndMagnitudeToTwosComplement(T sign, T magnitude) { static_assert(!std::numeric_limits::is_signed); @@ -680,6 +689,11 @@ auto ToSignMagnitude(T input) { return SignAndMagnitudeToTwosComplement(sign, magnitude); } +template <> +inline auto ToSignMagnitude(tsl::float8_e8m0fnu x) { + return Eigen::numext::bit_cast(x); +} + template constexpr int NanPayloadBits() { // Floating point types with signaling NaNs have payloads. diff --git a/xla/util_test.cc b/xla/util_test.cc index 7e3d12772b0a3..74dc0820c4f86 100644 --- a/xla/util_test.cc +++ b/xla/util_test.cc @@ -207,9 +207,9 @@ namespace { template void TotalOrderHelper(T x, T y) { auto x_sm = ToSignMagnitude(x); - bool x_sign = static_cast(Eigen::numext::signbit(x)); - bool y_sign = static_cast(Eigen::numext::signbit(y)); auto y_sm = ToSignMagnitude(y); + bool x_sign = static_cast(SignAndMagnitude(x).first); + bool y_sign = static_cast(SignAndMagnitude(y).first); if (x_sign && !y_sign) { EXPECT_LT(x_sm, y_sm) << x << " " << y; } @@ -338,6 +338,18 @@ TEST(UtilTest, TotalOrder_F8E3M4) { } } +TEST(UtilTest, TotalOrder_F8E8M0FNU) { + for (int a = 0; a < 256; ++a) { + tsl::float8_e8m0fnu x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e8m0fnu y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + void PackInt4(absl::Span input, absl::Span output) { CHECK_EQ(output.size(), CeilOfRatio(input.size(), size_t{2})); for (size_t i = 0; i < input.size(); ++i) { diff --git a/xla/xla_data.proto b/xla/xla_data.proto index 884908ab43c52..87a4b3b35c049 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -600,8 +600,9 @@ message LiteralProto { bytes f8e4m3s = 28; bytes f8e5m2fnuzs = 24; bytes f8e5m2s = 19; + bytes f8e8m0fnus = 31; repeated int64 sparse_indices = 14; - // Next = 31 + // Next = 32 } message WindowDimension {
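[Reviewer note, not part of the patch] With the SignAndMagnitude/ToSignMagnitude specializations above, every E8M0 value is treated as non-negative, so the total order exercised by TotalOrder_F8E8M0FNU reduces to comparing the raw encodings as unsigned bytes. A minimal sketch of that consequence, with the decoded powers of two spelled out in the comments for illustration:

#include <cassert>
#include <cstdint>

// With no sign bit, ToSignMagnitude(x) is just the E8M0 encoding itself, so
// "x orders before y" is an unsigned byte comparison.
bool E8M0TotalOrderLess(uint8_t x, uint8_t y) { return x < y; }

int main() {
  assert(E8M0TotalOrderLess(0x00, 0x7F));  // 2^-127 < 2^0 (= 1.0)
  assert(E8M0TotalOrderLess(0x7F, 0xFE));  // 1.0 < 2^127
  assert(E8M0TotalOrderLess(0xFE, 0xFF));  // every finite value sorts before NaN (0xFF)
  return 0;
}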