diff --git a/third_party/tsl/third_party/py/ml_dtypes/e8m0.patch b/third_party/tsl/third_party/py/ml_dtypes/e8m0.patch deleted file mode 100644 index add96113db620..0000000000000 --- a/third_party/tsl/third_party/py/ml_dtypes/e8m0.patch +++ /dev/null @@ -1,99 +0,0 @@ -diff --git a/include/float8.h b/include/float8.h -index 51d91fd..c53eb45 100644 ---- a/include/float8.h -+++ b/include/float8.h -@@ -1021,11 +1021,11 @@ struct numeric_limits_float8_e8m0fnu : public numeric_limits_float8_base { - static inline constexpr const int max_digits10 = - MaxDigits10FromDigits(digits); - // 2**-127 smallest valid normalized value.. -- static inline constexpr const int min_exponent = -127 + 1; -+ static inline constexpr const int min_exponent = -kExponentBias + 1; - static inline constexpr const int min_exponent10 = - MinExponent10FromMinExponent(min_exponent); - // 128 encoding using for NaN -- static inline constexpr const int max_exponent = 127; -+ static inline constexpr const int max_exponent = kExponentBias + 1; - static inline constexpr const int max_exponent10 = - MaxExponent10FromMaxExponentAndDigits(max_exponent, digits); - static inline constexpr const bool is_iec559 = false; -@@ -1292,7 +1292,8 @@ struct Traits : public TraitsBase { - }; - - template --constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff) { -+constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff, -+ bool use_implicit_bit) { - // Round to nearest even by adding a bias term. - // Consider a bit pattern - // FFF...FLRTT...T, -@@ -1301,9 +1302,12 @@ constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff) { - // - L is 1, R is 1, OR - // - L is 0, R is 1, any T is one. - // We do this by adding L to a bit pattern consisting of all T = 1. -- Bits bias = roundoff == 0 -- ? 0 -- : ((bits >> roundoff) & 1) + (Bits{1} << (roundoff - 1)) - 1; -+ // -+ // When rounding to zero mantissa (E8M0 type), the L bit is implicitly 1 (do -+ // not use the exponent bits for rounding). Add only the R bit in this case. -+ Bits bias = !use_implicit_bit -+ ? ((bits >> roundoff) & 1) + (Bits{1} << (roundoff - 1)) - 1 -+ : Bits{1} << (roundoff - 1); - return bits + bias; - } - -@@ -1443,6 +1447,7 @@ struct ConvertImpl> kFromMantissaBits; -+ const bool to_zero_mantissa = kToMantissaBits == 0; - - // `To` supports more exponents near zero which means that some subnormal - // values in `From` may become normal. -@@ -1473,11 +1478,14 @@ struct ConvertImpl 0) { -+ if constexpr (kDigitShift >= 0) { - bits <<= kDigitShift; - } else { - if constexpr (!kTruncate) { -- bits = RoundBitsToNearestEven(bits, -kDigitShift); -+ // When converting float to e8m0, the bits represent a denormal, -+ // so don't use the implicit mantissa bit for rounding. -+ bits = RoundBitsToNearestEven( -+ bits, -kDigitShift, to_zero_mantissa && kExponentOffset != 0); - } - bits >>= -kDigitShift; - } -@@ -1514,8 +1522,8 @@ struct ConvertImpl> exponent_shift; - } -@@ -1532,7 +1540,8 @@ struct ConvertImpl { - } - - if constexpr (!kTruncate) { -- from_bits = RoundBitsToNearestEven(from_bits, 8); -+ from_bits = RoundBitsToNearestEven(from_bits, 8, false); - // Rounding can cause an overflow to infinity. Clamp to the largest finite - // value if saturation is requested. - if constexpr (kSaturate) { diff --git a/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl b/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl index 1478f3cada677..0047319ecd918 100644 --- a/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl +++ b/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl @@ -12,7 +12,6 @@ def repo(): tf_http_archive( name = "ml_dtypes", build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD", - patch_file = ["//third_party/py/ml_dtypes:e8m0.patch"], link_files = { "//third_party/py/ml_dtypes:ml_dtypes.tests.BUILD": "tests/BUILD.bazel", "//third_party/py/ml_dtypes:LICENSE": "LICENSE", diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index 594ad973003fd..2b376944964c1 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,5 +1,8 @@ # PJRT C API changelog +## 0.60 +* Added types F4E2M1FN and F8E8M0FNU. + ## 0.59 * Added ``PJRT_MemoryDescriptions_Extension``. diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index ad1b774bc7f15..c740e1df5e4d5 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 59 +#define PJRT_API_MINOR 60 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in diff --git a/xla/primitive_util_test.cc b/xla/primitive_util_test.cc index da1a6afdf3b87..eebb1a640870a 100644 --- a/xla/primitive_util_test.cc +++ b/xla/primitive_util_test.cc @@ -960,7 +960,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { << primitive_util::LowercasePrimitiveTypeName(to_type); } } -} +} // NOLINT(readability/fn_size) } // namespace } // namespace xla diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index fe928cfee7547..9c955af14714e 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -1436,7 +1436,8 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( // Cast to F16 first. Casts to F4E2M1FN must be from F16. if (from_type != F16) { operand_value = b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + operand_value, + llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext())); } return EmitF16ToF4e2m1fn(operand_value, b_); } @@ -1444,7 +1445,8 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( // Cast to F32 first. Casts to F8E8M0FNU must be from F32. if (from_type != F32) { operand_value = b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(F32, module_)); + operand_value, + llvm_ir::PrimitiveTypeToIrType(F32, module_->getContext())); } return EmitF32ToF8e8m0fnu(operand_value, b_); } diff --git a/xla/service/gpu/fusions/triton/triton_support_test.cc b/xla/service/gpu/fusions/triton/triton_support_test.cc index 5d0c696ccc980..15a94457268d0 100644 --- a/xla/service/gpu/fusions/triton/triton_support_test.cc +++ b/xla/service/gpu/fusions/triton/triton_support_test.cc @@ -550,9 +550,18 @@ INSTANTIATE_TEST_SUITE_P( using ReduceTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam; +static std::string_view init_value(PrimitiveType dtype) { + if (dtype == C64 || dtype == C128) { + return "(0, 0)"; + } else if (dtype == F8E8M0FNU) { + return "1e-40"; + } else { + return "0"; + } +} + TEST_P(ReduceTest, IsTritonSupportedReduction) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( add { @@ -567,7 +576,7 @@ ENTRY triton_computation { ROOT reduce = $0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=add })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -599,7 +608,6 @@ TEST_P( ReduceTest, UnsupportedReduceWithMoreThanOneReduceDimensionsFailsGracefullyWithTriton) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( add { @@ -614,7 +622,7 @@ ENTRY triton_computation { ROOT reduce = $0[2] reduce(parameter_0, constant_0), dimensions={1,2}, to_apply=add })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -624,7 +632,6 @@ ENTRY triton_computation { TEST_P(ReduceTest, IsTritonSupportedReduceWithNonLastReduceDimension) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( add { @@ -638,7 +645,7 @@ ENTRY triton_computation { constant_0 = $0[] constant($1) ROOT reduce = $0[127] reduce(parameter_0, constant_0), dimensions={0}, to_apply=add })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -649,7 +656,6 @@ ENTRY triton_computation { TEST_P(ReduceTest, UnsupportedReduceWithMoreThanOneOperandsFailsGracefullyWithTriton) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( add { @@ -670,7 +676,7 @@ ENTRY triton_computation { dimensions={1}, to_apply=add ROOT reduce = $0[125] get-tuple-element(tuple), index=0 })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -701,7 +707,6 @@ ENTRY triton_computation { TEST_P(ReduceTest, UnsupportedReductionComputationFailsGracefullyWithTriton) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( custom_call { @@ -716,7 +721,7 @@ ENTRY triton_computation { ROOT reduce = $0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=custom_call })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -740,7 +745,6 @@ using ReductionComputationTest = // computation and in regular HLO. See triton_support.cc for more details. TEST_P(ReductionComputationTest, DifferentBinaryOps) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute( R"( reduce_computation { @@ -755,7 +759,7 @@ ENTRY triton_computation { ROOT reduce = $0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=reduce_computation })", - "$0", HloOpcodeString(opcode), dtype_is_complex ? "(0, 0)" : "0"); + "$0", HloOpcodeString(opcode), init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -1115,13 +1119,12 @@ TEST_P(ConstantTest, ConstantEffectiveScalar) { // The IsTritonSupportedReduction effectively tests the scalar constant // support. auto [data_type, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( ENTRY triton_computation { ROOT const = $0[1,1] constant({{$1}}) })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( kHloTestTemplate, data_type, @@ -1133,13 +1136,12 @@ TEST_P(ConstantTest, Constant2D) { // The IsTritonSupportedReduction effectively tests the scalar constant // support. auto [data_type, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( ENTRY triton_computation { ROOT const = $0[3,3] constant({{$1,$1,$1},{$1,$1,$1},{$1,$1,$1}}) })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( kHloTestTemplate, data_type, diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index 95d6525ab300b..3e0794c5eed35 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -2146,15 +2146,15 @@ XLA_TYPED_TEST(ConvertTestT, ConvertF8e8m0fnuRoundtripExhaustive) { } XLA_TYPED_TEST(ConvertTestT, ConvertF8e8m0fnuRoundtripExhaustive2) { -#ifdef XLA_TEST_BACKEND_CPU - // This test is disabled on CPU, as converting 0x1p-127 from double to float - // using CVTSD2SS on x64 results in an underflow (even though the result is - // representable as denormalized float32). - if (std::is_same_v) { - GTEST_SKIP() << "Skipping test for double precision floating point that " - "loses denormal value during conversion"; - } -#endif + if (this->client_->platform()->Name() == "Host") { + // This test is disabled on CPU, as converting 0x1p-127 from double to float + // using CVTSD2SS on x64 results in an underflow (even though the result is + // representable as denormalized float32). + if (std::is_same_v) { + GTEST_SKIP() << "Skipping test for double precision floating point that " + "loses denormal value during conversion"; + } + } // Convert from supported floating point type to FP8. XlaBuilder builder(this->TestName());