diff --git a/third_party/tsl/tsl/platform/BUILD b/third_party/tsl/tsl/platform/BUILD index 9f3b135245a40..eb208e6ef65ba 100644 --- a/third_party/tsl/tsl/platform/BUILD +++ b/third_party/tsl/tsl/platform/BUILD @@ -985,7 +985,6 @@ cc_library( deps = [ "@ml_dtypes//:float8", "@ml_dtypes//:intn", - "@ml_dtypes//:mxfloat", ], ) diff --git a/third_party/tsl/tsl/platform/ml_dtypes.h b/third_party/tsl/tsl/platform/ml_dtypes.h index a03fa02447f3c..a6a1b56af88ad 100644 --- a/third_party/tsl/tsl/platform/ml_dtypes.h +++ b/third_party/tsl/tsl/platform/ml_dtypes.h @@ -18,10 +18,8 @@ limitations under the License. #include "ml_dtypes/include/float8.h" // from @ml_dtypes #include "ml_dtypes/include/intn.h" // from @ml_dtypes -#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes namespace tsl { -using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn; using float8_e3m4 = ::ml_dtypes::float8_e3m4; using float8_e4m3 = ::ml_dtypes::float8_e4m3; using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; @@ -29,7 +27,6 @@ using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; using float8_e5m2 = ::ml_dtypes::float8_e5m2; using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz; -using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu; using int1 = ::ml_dtypes::int1; using uint1 = ::ml_dtypes::uint1; diff --git a/xla/array2d_test.cc b/xla/array2d_test.cc index c62f6e882713e..921da30256fa3 100644 --- a/xla/array2d_test.cc +++ b/xla/array2d_test.cc @@ -219,34 +219,6 @@ TEST(Array2dTest, LinspaceF8E3M4) { EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); } -TEST(Array2dTest, LinspaceF4E2M1FN) { - auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); - - EXPECT_EQ(arr->n1(), 3); - EXPECT_EQ(arr->n2(), 2); - - EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); - EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); - EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); - EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.0); // 2.5 rounded down - EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); - EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 4.0); // 3.5 rounded up -} - -TEST(Array2dTest, LinspaceF8E8M0FNU) { - auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); - - EXPECT_EQ(arr->n1(), 3); - EXPECT_EQ(arr->n2(), 2); - - EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); - EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 2.0); // 1.5 rounded up - EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); - EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.0); // 2.5 rounded down - EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 4.0); // 3.0 rounded up - EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 4.0); // 3.5 rounded up -} - TEST(Array2dTest, Stringification) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); const std::string expected = R"([[1, 1.5], diff --git a/xla/backends/gpu/codegen/transforms/expand_float_ops.cc b/xla/backends/gpu/codegen/transforms/expand_float_ops.cc index ff2ce86227798..81cb99d66f82d 100644 --- a/xla/backends/gpu/codegen/transforms/expand_float_ops.cc +++ b/xla/backends/gpu/codegen/transforms/expand_float_ops.cc @@ -163,13 +163,7 @@ int GetSignificandBits(mlir::FloatType ty) { } int GetExponentBias(mlir::FloatType ty) { - return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()) - - ty.isFloat8E8M0FNU(); // No zero exponent for E8M0. -} - -bool IsFNUZ(mlir::FloatType ty) { - return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() || - ty.isFloat8E5M2FNUZ(); + return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()); } Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { @@ -181,7 +175,7 @@ Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { return b.create(ma::CmpFPredicate::OEQ, value, inf); } - assert(ty.getIntOrFloatBitWidth() <= 8); + assert(ty.getIntOrFloatBitWidth() == 8); // F8E5M2, F8E4M3, F8E3M4 are the only 8 bit float with infinities. if (ty.isFloat8E5M2()) { Val bits{b.create(b.getI8Type(), value), &b}; @@ -202,9 +196,6 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { if (mlir::LLVM::isCompatibleOuterType(ty)) { return b.create(ma::CmpFPredicate::UNO, value, value); } - if (ty.isFloat4E2M1FN()) { - return b.create(false, b.getI1Type()); - } assert(ty.getIntOrFloatBitWidth() == 8); Val bits{b.create(b.getI8Type(), value), &b}; @@ -216,8 +207,6 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { return (bits & 0b0111'1111) == 0b0111'1111; } else if (ty.isFloat8E3M4()) { return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'0000); - } else if (ty.isFloat8E8M0FNU()) { - return bits == 0xFF; } return bits == 0x80; } @@ -292,18 +281,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, auto to_int_ty = b.getIntegerType(to_ty.getIntOrFloatBitWidth()); mlir::IntegerType wide_int_ty; - if (from_ty.getWidth() <= 8 && to_ty.getWidth() <= 8) { + if (from_ty.getWidth() == 8 && to_ty.getWidth() == 8) { wide_int_ty = b.getI16Type(); } else { wide_int_ty = b.getIntegerType( std::max(from_int_ty.getWidth(), to_int_ty.getWidth())); - // Avoid overflow for bit shifts. - auto may_overflow = [&](mlir::Type a, mlir::Type b) { - return a.isFloat8E8M0FNU() && b.isF16(); - }; - if (may_overflow(from_ty, to_ty) || may_overflow(to_ty, from_ty)) { - wide_int_ty = b.getI32Type(); - } } auto convert_int = [&](mlir::Type ty, Value v) -> Val { if (v.getType() == ty) { @@ -318,49 +300,34 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, int64_t exp_offset = to_bias - from_bias; int digit_shift = to_mantissa - from_mantissa; - int from_width = value.getType().getIntOrFloatBitWidth(); - Val from_bits{b.create(b.getIntegerType(from_width), value), - &b}; - if (from_width < 8) { - from_bits = convert_int(b.getIntegerType(8), from_bits); - } + Val from_bits{ + b.create( + b.getIntegerType(value.getType().getIntOrFloatBitWidth()), value), + &b}; auto cst = [&](mlir::Type ty, int64_t n) -> Val { return {b.create(n, ty), &b}; }; // Shift bits to destination type, without sign bit. - Val from_sign_bit; - if (!from_ty.isFloat8E8M0FNU()) { - from_sign_bit = from_bits.shrui(from_width - 1) != 0; - from_bits = from_bits & ((1ULL << (from_width - 1)) - 1); - } + Val from_sign_bit = + from_bits.shrui(value.getType().getIntOrFloatBitWidth() - 1) != 0; + + from_bits = + from_bits & ((1ULL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1); + + Value result_is_inf = IsInf(value, b); + Value input_is_nan = IsNaN(value, b); auto cst_bits = [&](llvm::APFloat f) { return cst(b.getIntegerType(llvm::APFloat::getSizeInBits(f.getSemantics())), f.bitcastToAPInt().getZExtValue()); }; - Value to_nan; - Value to_inf; - Val to_zero; - - // MX float types have neither infinities nor NaNs. - if (to_ty.isFloat4E2M1FN()) { - to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); - to_nan = to_zero | 0x8; - to_inf = cst_bits(llvm::APFloat::getLargest(to_ty.getFloatSemantics())); - } else if (to_ty.isFloat8E8M0FNU()) { - to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); - to_inf = to_nan; - to_zero = Val{to_nan, &b}; - } else { - to_inf = cst_bits(llvm::APFloat::getInf(to_ty.getFloatSemantics())); - to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); - to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); - } + Value to_inf = cst_bits(llvm::APFloat::getInf(to_ty.getFloatSemantics())); + Value to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); + Val to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); - auto round_bits_to_nearest_even = [&](Val bits, Val roundoff, - bool use_implicit_bit = false) { + auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) { assert(bits.value.getType() == roundoff.value.getType()); // Round to nearest even by adding a bias term. // Consider a bit pattern @@ -370,10 +337,9 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, // - L is 1, R is 1, OR // - L is 0, R is 1, any T is one. // We do this by adding L to a bit pattern consisting of all T = 1. - Val bias = !use_implicit_bit - ? (bits.shrui(roundoff) & 1) + - (bits.MakeConstant(1).shl(roundoff - 1) - 1) - : bits.MakeConstant(1).shl(roundoff - 1); + Val rounded = (bits.shrui(roundoff) & 1) + + (bits.MakeConstant(1).shl(roundoff - 1) - 1); + Val bias{b.create(roundoff == 0, roundoff, rounded), &b}; return bits + bias; }; @@ -383,11 +349,9 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, // Round the mantissa if it is shrinking. Val rounded_from_bits = convert_int(wide_int_ty, from_bits); if (digit_shift < 0) { - rounded_from_bits = - round_bits_to_nearest_even( - rounded_from_bits, rounded_from_bits.MakeConstant(-digit_shift), - /*use_implicit_bit=*/to_mantissa == 0) & - ~((1ll << (-digit_shift)) - 1); + rounded_from_bits = round_bits_to_nearest_even( + from_bits, from_bits.MakeConstant(-digit_shift)) & + ~((1ll << (-digit_shift)) - 1); } // Re-bias the exponent. @@ -430,10 +394,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, Val bits = convert_int(wide_int_ty, from_bits); // Determine exponent in target type. - Value clz = convert_int( - i32_ty, b.create(from_bits)); - Value msb = cst(i32_ty, std::max(from_width, 8) - 1) - clz; - Value normalization_factor = cst(i32_ty, from_mantissa) - msb; + Value normalization_factor = + convert_int(i32_ty, + b.create(from_bits)) - + (from_int_ty.getWidth() - from_mantissa - 1); Val biased_exponent = cst(i32_ty, exp_offset + 1) - normalization_factor; // If the result is subnormal, adjust the subnormal bits to account for @@ -454,12 +418,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, Value biased_exp_sle_zero = biased_exponent.cmp(CmpIPredicate::sle, 0); bits.value = b.create(biased_exp_sle_zero, subnormal_bits, normal_bits); - if (digit_shift >= 0) { + if (digit_shift > 0) { bits = bits.shl(digit_shift); } else { - bits = round_bits_to_nearest_even( - bits, bits.MakeConstant(-digit_shift), - /*use_implicit_bit=*/to_mantissa == 0 && exp_offset != 0); + bits = round_bits_to_nearest_even(bits, bits.MakeConstant(-digit_shift)); bits = bits.shrui(-digit_shift); } bits = convert_int(to_int_ty, bits); @@ -468,11 +430,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, } else if (to_min_exp > from_min_exp) { // `To` supports fewer exponents near zero which means that some values in // `From` may become subnormal. - Val biased_to_exp = biased_from_exp + (to_bias - from_bias); + Val unbiased_exp = biased_from_exp - from_bias; + Val biased_to_exp = unbiased_exp + to_bias; // Subnormals and zero. // Round and shift mantissa down. - Val from_has_leading_one = - !from_ty.isFloat8E8M0FNU() ? biased_from_exp != 0 : cst(i32_ty, 1); + Val from_has_leading_one = biased_from_exp != 0; Val from_has_leading_one_i32 = convert_int(i32_ty, from_has_leading_one); from_has_leading_one = convert_int(from_int_ty, from_has_leading_one); Val exponent_shift_i32 = @@ -507,35 +469,31 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, result); } - Value result_is_inf = IsInf(value, b); - Value input_is_nan = IsNaN(value, b); + // Handle types with no unsigned zero. + auto is_nuz = [](mlir::FloatType ty) { + return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() || + ty.isFloat8E5M2FNUZ(); + }; - if (to_ty.isFloat8E8M0FNU()) { - // Converting a negative number to E8M0 results in NaN. - input_is_nan = from_sign_bit | input_is_nan; - } else if (IsFNUZ(to_ty)) { + if (is_nuz(to_ty)) { // Clear the sign bit if the result is zero (the output has no negative - // zero). Handle the edge case when the input is zero and the result is not. - Val result_is_non_zero = - (digit_shift > 0 ? from_bits : Val{result, &b}) != 0; + // zero). + Val result_is_non_zero = Val{result, &b} != 0; from_sign_bit = from_sign_bit & result_is_non_zero; - } else if (IsFNUZ(from_ty)) { + } else if (is_nuz(from_ty)) { // Clear the sign bit if the input is NaN (it's positive but encoded as // negative 0). from_sign_bit = from_sign_bit ^ input_is_nan; } - if (!from_ty.isFloat8E8M0FNU()) { - result = b.create(from_bits == 0, to_zero, result); - } result = b.create(result_is_inf, to_inf, result); + result = b.create(from_bits == 0, to_zero, result); result = b.create(input_is_nan, to_nan, result); + Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1)); + // Insert sign bit. - if (!from_ty.isFloat8E8M0FNU()) { - Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1)); - result = b.create(from_sign_bit, neg_result, result); - } + result = b.create(from_sign_bit, neg_result, result); result = b.create(to_ty, result); return result; } @@ -548,8 +506,8 @@ struct RewriteTruncFPattern : public mlir::OpRewritePattern { using FloatValue = mlir::TypedValue; auto src = mlir::cast(op.getOperand()); auto dst_ty = mlir::cast(op.getType()); - if (dst_ty.getWidth() > 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) truncf"); + if (dst_ty.getWidth() != 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit truncf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -566,8 +524,8 @@ struct RewriteExtFPattern : public mlir::OpRewritePattern { using FloatValue = mlir::TypedValue; auto src = mlir::cast(op.getOperand()); auto dst_ty = mlir::cast(op.getType()); - if (src.getType().getWidth() > 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) extf"); + if (src.getType().getWidth() != 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit extf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -586,8 +544,8 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { auto lhs = mlir::cast(op.getLhs()); auto rhs = mlir::cast(op.getRhs()); - if (lhs.getType().getWidth() > 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) cmpf"); + if (lhs.getType().getWidth() != 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit cmpf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -595,16 +553,16 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { llvm::APFloat rhs_cst(rhs.getType().getFloatSemantics()); if (op.getPredicate() == ma::CmpFPredicate::UNE && mlir::matchPattern(rhs, mlir::m_ConstantFloat(&rhs_cst))) { - mlir::Type int_ty = rewriter.getIntegerType(lhs.getType().getWidth()); - Val int_value{b.create(int_ty, lhs), &b}; + Val int_value{b.create(rewriter.getI8Type(), lhs), &b}; int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue(); // If we're comparing to +-0, compare the absolute values. - if (rhs_cst.isZero() && !IsFNUZ(lhs.getType())) { - int64_t mask = (1 << (lhs.getType().getWidth() - 1)) - 1; - int_value = int_value & mask; - constant &= mask; + if (rhs_cst.isZero() && + (lhs.getType().isFloat8E3M4() || lhs.getType().isFloat8E4M3() || + lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) { + int_value = int_value & 0x7f; + constant &= 0x7f; } - auto cst = b.create(constant, int_ty); + auto cst = b.create(constant, rewriter.getI8Type()); rewriter.replaceOpWithNewOp(op, ma::CmpIPredicate::ne, int_value, cst); return mlir::success(); @@ -628,23 +586,18 @@ struct RewriteAbsFPattern : public mlir::OpRewritePattern { auto src = mlir::cast(op.getOperand()); // LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16. // Once that's removed, remove the code for BF16 here. - if (src.getType().getWidth() > 8 && !src.getType().isBF16()) { - return rewriter.notifyMatchFailure(op, - "not an f8 (or less) or bf16 absf"); + if (src.getType().getWidth() != 8 && !src.getType().isBF16()) { + return rewriter.notifyMatchFailure(op, "not an f8 or bf16 absf"); } - - // If type is unsigned (E8M0), the operation is no-op. - if (!llvm::APFloat::semanticsHasSignedRepr( - src.getType().getFloatSemantics())) { - rewriter.replaceAllOpUsesWith(op, op.getOperand()); - return mlir::success(); - } - mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth()); Val value{b.create(i_ty, src), &b}; - int64_t mask = (1ull << (src.getType().getWidth() - 1)) - 1; - value = value & mask; + if (src.getType().getWidth() == 8) { + value = value & 0x7f; + } else { + CHECK(src.getType().isBF16()); + value = value & 0x7fff; + } rewriter.replaceOpWithNewOp(op, src.getType(), value); return mlir::success(); } @@ -656,8 +609,8 @@ struct RewriteIToFpPattern : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter& rewriter) const override { - if (op.getType().getIntOrFloatBitWidth() > 8) { - return rewriter.notifyMatchFailure(op, "not an f8 (or less) itofp"); + if (op.getType().getIntOrFloatBitWidth() != 8) { + return rewriter.notifyMatchFailure(op, "not an f8 itofp"); } Value to_float = rewriter.create(op.getLoc(), rewriter.getF32Type(), op.getIn()); @@ -672,8 +625,8 @@ struct RewriteFpToIPattern : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter& rewriter) const override { - if (op.getIn().getType().getIntOrFloatBitWidth() > 8) { - return rewriter.notifyMatchFailure(op, "not an f8 (or less) fptoi"); + if (op.getIn().getType().getIntOrFloatBitWidth() != 8) { + return rewriter.notifyMatchFailure(op, "not an f8 fptoi"); } Value to_f32 = rewriter.create( op.getLoc(), rewriter.getF32Type(), op.getIn()); diff --git a/xla/backends/gpu/codegen/transforms/lower_tensors.cc b/xla/backends/gpu/codegen/transforms/lower_tensors.cc index 31737323d78e4..38e3671f9613f 100644 --- a/xla/backends/gpu/codegen/transforms/lower_tensors.cc +++ b/xla/backends/gpu/codegen/transforms/lower_tensors.cc @@ -297,8 +297,7 @@ std::tuple GetI4IndexAndNibble(Value linear_index, mlir::LLVM::GEPOp CreateGep(TypedValue tensor, Value linear_index, mlir::ImplicitLocOpBuilder& b) { Type element_type = tensor.getType().getElementType(); - if (element_type.isIntOrFloat() && - element_type.getIntOrFloatBitWidth() == 4) { + if (element_type == b.getI4Type()) { element_type = b.getI8Type(); } auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext()); @@ -327,8 +326,7 @@ struct RewriteTensorExtract : OpRewritePattern { auto linear_index = GetLinearIndex(op.getIndices(), b); Type element_type = op.getTensor().getType().getElementType(); Value is_low_nibble = nullptr; - if (element_type.isIntOrFloat() && - element_type.getIntOrFloatBitWidth() == 4) { + if (element_type == rewriter.getI4Type()) { std::tie(linear_index, is_low_nibble) = GetI4IndexAndNibble(linear_index, b); } @@ -343,7 +341,7 @@ struct RewriteTensorExtract : OpRewritePattern { auto high_value = b.create( load, b.create(4, load.getType())); load = b.create( - rewriter.getI4Type(), + op.getType(), b.create(is_low_nibble, load, high_value)); } @@ -379,7 +377,6 @@ struct RewriteTransferRead : OpRewritePattern { auto source = mlir::dyn_cast>( op.getSource()); - mlir::Type source_element_type = source.getType().getElementType(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto linear_index = GetLinearIndex(op.getIndices(), b); @@ -388,9 +385,7 @@ struct RewriteTransferRead : OpRewritePattern { if (vector_type.getElementType().isInteger(1)) { vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type()); } - mlir::Type gep_element_type = vector_type.getElementType(); - if (gep_element_type.isIntOrFloat() && - gep_element_type.getIntOrFloatBitWidth() == 4) { + if (op.getVectorType().getElementType().isInteger(4)) { linear_index = b.create( linear_index, b.create(1, linear_index.getType())); @@ -402,12 +397,11 @@ struct RewriteTransferRead : OpRewritePattern { auto loaded = b.create(llvm_vector_type, gep).getResult(); - if (source_element_type.isInteger(1)) { + if (source.getType().getElementType().isInteger(1)) { Value zero = b.create( mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0))); loaded = b.create(arith::CmpIPredicate::ne, loaded, zero); - } else if (source_element_type.isIntOrFloat() && - source_element_type.getIntOrFloatBitWidth() == 4) { + } else if (source.getType().getElementType().isInteger(4)) { // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the // elements. loaded = PermutePairsInVector(loaded, b); @@ -436,8 +430,7 @@ struct RewriteTensorInsert : OpRewritePattern { auto scalar_value = op.getScalar(); // For i4 we store 2 values into one byte. This needs special handling here. - if (tensor_dest.getType().getElementType().isIntOrFloat() && - tensor_dest.getType().getElementType().getIntOrFloatBitWidth() == 4) { + if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) { // We need to use directly op.getDest() as input, otherwise the following // rewrite might remove the only user of it. tensor_dest = op.getDest(); @@ -455,10 +448,6 @@ struct RewriteTensorInsert : OpRewritePattern { auto tensor_dest_i8 = b.create(tensor_ty, tensor_dest) .getResult(0); - if (scalar_value.getType() != rewriter.getI4Type()) { - scalar_value = - b.create(rewriter.getI4Type(), scalar_value); - } scalar_value = b.create(ty, scalar_value); // We need AtomicRMWOp because it can happen that different threads try to @@ -518,14 +507,12 @@ struct RewriteTransferWrite : OpRewritePattern { auto linear_index = GetLinearIndex(op.getIndices(), b); mlir::Value vector_value = op.getVector(); - mlir::Type vector_element_type = op.getVectorType().getElementType(); - if (vector_element_type.isInteger(1)) { + if (op.getVectorType().getElementType().isInteger(1)) { vector_value = b.create( op.getVectorType().cloneWith(std::nullopt, b.getI8Type()), vector_value); } - if (vector_element_type.isIntOrFloat() && - vector_element_type.getIntOrFloatBitWidth() == 4) { + if (op.getVectorType().getElementType().isInteger(4)) { linear_index = b.create( linear_index, b.create(1, linear_index.getType())); @@ -590,19 +577,21 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value, // Needed to support complex element type. mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_element_type = converter.convertType(element_type); - if (element_type.isIntOrFloat() && - element_type.getIntOrFloatBitWidth() == 4) { - num_elements = CeilOfRatio(num_elements, 2); - llvm_element_type = b.getI8Type(); - auto unpacked_data = - mlir::cast(value).getRawData(); - std::vector packed_data(num_elements); - absl::Span packed_data_span = - absl::MakeSpan(packed_data.data(), packed_data.size()); - PackIntN(4, unpacked_data, packed_data_span); - value = mlir::DenseElementsAttr::getFromRawBuffer( - mlir::RankedTensorType::get({num_elements}, llvm_element_type), - packed_data); + if (mlir::isa(element_type)) { + int bit_width = mlir::cast(element_type).getWidth(); + if (bit_width == 4) { + num_elements = CeilOfRatio(num_elements, 2); + llvm_element_type = b.getI8Type(); + auto unpacked_data = + mlir::cast(value).getRawData(); + std::vector packed_data(num_elements); + absl::Span packed_data_span = + absl::MakeSpan(packed_data.data(), packed_data.size()); + PackIntN(4, unpacked_data, packed_data_span); + value = mlir::DenseElementsAttr::getFromRawBuffer( + mlir::RankedTensorType::get({num_elements}, llvm_element_type), + packed_data); + } } auto array_ty = mlir::LLVM::LLVMArrayType::get(llvm_element_type, num_elements); diff --git a/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir b/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir index dea8988d474b0..442fe5e929157 100644 --- a/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir +++ b/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir @@ -115,53 +115,3 @@ module { // CHECK: %[[EXT:.*]] = arith.extf {{.*}} : bf16 to f32 // CHECK: arith.truncf %[[EXT]] : f32 to f16 // CHECK-NOT: arith.truncf - -// ----- - -module { - func.func @f4_to_f16(%arg0: f4E2M1FN) -> f16 { - %ret = arith.extf %arg0 : f4E2M1FN to f16 - return %ret : f16 - } -} - -// CHECK-LABEL: @f4_to_f16 -// CHECK-NOT: arith.extf - -// ----- - -module { - func.func @f16_to_f4(%arg0: f16) -> f4E2M1FN { - %ret = arith.truncf %arg0 : f16 to f4E2M1FN - return %ret : f4E2M1FN - } -} - -// CHECK-LABEL: @f16_to_f4 -// CHECK-NOT: arith.truncf - -// ----- - -module { - func.func @f4_abs(%arg0: f4E2M1FN) -> f4E2M1FN { - %ret = math.absf %arg0 : f4E2M1FN - return %ret : f4E2M1FN - } -} - -// CHECK-LABEL: @f4_abs -// CHECK-NOT: math.absf -// CHECK: arith.constant 7 : i4 - -// ----- - -module { - func.func @e8m0_abs(%arg0: f8E8M0FNU) -> f8E8M0FNU { - %ret = math.absf %arg0 : f8E8M0FNU - return %ret : f8E8M0FNU - } -} - -// CHECK-LABEL: @e8m0_abs -// CHECK-NOT: math.absf -// CHECK: return %arg0 diff --git a/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir b/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir index 864f68d1da6f4..646c7a00ff756 100644 --- a/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir +++ b/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir @@ -763,44 +763,4 @@ func.func @for_op(%arg0: tensor<500xf32>) -> f32 { // CHECK-LABEL: @for_op // CHECK: scf.for {{.*}} -> (vector<4xf32>) { -// CHECK-NEXT: scf.for {{.*}} -> (vector<4xf32>) { - -// ----- - -func.func @f4_constant(%arg0: tensor<3xf4E2M1FN>, %arg1: index) -> f4E2M1FN { - %cst = arith.constant dense<[0.5, -0.5, 2.5]> : tensor<3xf4E2M1FN> - %extracted = tensor.extract %arg0[%arg1] : tensor<3xf4E2M1FN> - %extracted_0 = tensor.extract %cst[%arg1] : tensor<3xf4E2M1FN> - %0 = arith.addf %extracted, %extracted_0 : f4E2M1FN - return %0 : f4E2M1FN -} -// CHECK: llvm.mlir.global private constant -// CHECK-SAME: dense<[25, 64]> -// CHECK-LABEL: @f4_constant - -// ----- - -func.func @transfer_read_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}) -> vector<2xf4E2M1FN> { - %c16 = arith.constant 16 : index - %c0 = arith.constant 0.0 : f4E2M1FN - %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf4E2M1FN>, vector<2xf4E2M1FN> - func.return %out : vector<2xf4E2M1FN> -} -// CHECK-LABEL: @transfer_read_f4 -// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[8] -// CHECK: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xi4> -// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xi4> to vector<2xf4E2M1FN> -// CHECK: return %[[OUT]] : vector<2xf4E2M1FN> - -// ----- - -func.func @transfer_write_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}, - %arg1: vector<2xf4E2M1FN>) -> tensor<43xf4E2M1FN> { - %c10 = arith.constant 10 : index - %out = vector.transfer_write %arg1, %arg0[%c10] : vector<2xf4E2M1FN>, tensor<43xf4E2M1FN> - func.return %out : tensor<43xf4E2M1FN> -} -// CHECK-LABEL: @transfer_write_f4 -// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %arg0[5] : (!llvm.ptr) -> !llvm.ptr, i8 -// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xf4E2M1FN> to vector<2xi4> -// CHECK: llvm.store %[[OUT]], %[[PTR]] : vector<2xi4>, !llvm.ptr +// CHECK-NEXT: scf.for {{.*}} -> (vector<4xf32>) { \ No newline at end of file diff --git a/xla/comparison_util.h b/xla/comparison_util.h index 44f0dd48640bb..5a21595da4d74 100644 --- a/xla/comparison_util.h +++ b/xla/comparison_util.h @@ -193,13 +193,8 @@ class Comparison { // -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN // Reference: // https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations - if constexpr (std::numeric_limits::is_signed) { - using R = SignedIntegerTypeForSizeType; - return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); - } else { - using R = UnsignedIntegerTypeForSizeType; - return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); - } + using R = SignedIntegerTypeForSizeType; + return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); } } // Applies the comparison from this Comparison's direction and ordering. diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h index 9787476f8f7ea..389d2d2a9a7ae 100644 --- a/xla/ffi/api/api.h +++ b/xla/ffi/api/api.h @@ -131,8 +131,6 @@ inline std::ostream& operator<<(std::ostream& os, return os << "C128"; case XLA_FFI_DataType_TOKEN: return os << "TOKEN"; - case XLA_FFI_DataType_F4E2M1FN: - return os << "F4E2M1FN"; case XLA_FFI_DataType_F8E5M2: return os << "F8E5M2"; case XLA_FFI_DataType_F8E3M4: @@ -147,8 +145,6 @@ inline std::ostream& operator<<(std::ostream& os, return os << "F8E5M2FNUZ"; case XLA_FFI_DataType_F8E4M3FNUZ: return os << "F8E4M3FNUZ"; - case XLA_FFI_DataType_F8E8M0FNU: - return os << "F8E8M0FNU"; } } diff --git a/xla/ffi/api/c_api.h b/xla/ffi/api/c_api.h index bf8cb7d1a8ad1..8d6f1095fad24 100644 --- a/xla/ffi/api/c_api.h +++ b/xla/ffi/api/c_api.h @@ -201,8 +201,6 @@ typedef enum { XLA_FFI_DataType_F8E4M3B11FNUZ = 23, XLA_FFI_DataType_F8E5M2FNUZ = 24, XLA_FFI_DataType_F8E4M3FNUZ = 25, - XLA_FFI_DataType_F4E2M1FN = 32, - XLA_FFI_DataType_F8E8M0FNU = 33, } XLA_FFI_DataType; // LINT.ThenChange(ffi_test.cc) diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h index aeeab1d505ab6..f264451da3473 100644 --- a/xla/ffi/api/ffi.h +++ b/xla/ffi/api/ffi.h @@ -79,8 +79,6 @@ enum class DataType : uint8_t { F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ, F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ, F8E3M4 = XLA_FFI_DataType_F8E3M4, - F4E2M1FN = XLA_FFI_DataType_F4E2M1FN, - F8E8M0FNU = XLA_FFI_DataType_F8E8M0FNU, }; // Create aliases in ::xla::ffi namespace for all DataTypes, for consistency @@ -108,8 +106,6 @@ inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ; inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ; inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ; inline constexpr DataType F8E3M4 = DataType::F8E3M4; -inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN; -inline constexpr DataType F8E8M0FNU = DataType::F8E8M0FNU; inline std::ostream& operator<<(std::ostream& os, const DataType dtype) { return os << static_cast(dtype); @@ -131,8 +127,6 @@ constexpr size_t ByteWidth(DataType dtype) { case DataType::F8E5M2FNUZ: case DataType::F8E4M3FNUZ: case DataType::F8E3M4: - case DataType::F4E2M1FN: - case DataType::F8E8M0FNU: return 1; case DataType::S16: case DataType::U16: diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index 3c51a0966ae02..f09588b9e986a 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -129,7 +129,6 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN)); - EXPECT_EQ(encoded(PrimitiveType::F4E2M1FN), encoded(DataType::F4E2M1FN)); EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN)); @@ -138,7 +137,6 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4)); - EXPECT_EQ(encoded(PrimitiveType::F8E8M0FNU), encoded(DataType::F8E8M0FNU)); } TEST(FfiTest, DataTypeByteWidth) { @@ -181,8 +179,6 @@ TEST(FfiTest, DataTypeByteWidth) { EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::C128), ByteWidth(DataType::C128)); - EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F4E2M1FN), - ByteWidth(DataType::F4E2M1FN)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2), ByteWidth(DataType::F8E5M2)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3), @@ -197,8 +193,6 @@ TEST(FfiTest, DataTypeByteWidth) { ByteWidth(DataType::F8E4M3FNUZ)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4), ByteWidth(DataType::F8E3M4)); - EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E8M0FNU), - ByteWidth(DataType::F8E8M0FNU)); } TEST(FfiTest, ErrorEnumValue) { diff --git a/xla/ffi/call_frame.cc b/xla/ffi/call_frame.cc index 7bcb14da445e8..3fb2ac3c7786f 100644 --- a/xla/ffi/call_frame.cc +++ b/xla/ffi/call_frame.cc @@ -264,7 +264,6 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::C64: case PrimitiveType::C128: case PrimitiveType::TOKEN: - case PrimitiveType::F4E2M1FN: case PrimitiveType::F8E5M2: case PrimitiveType::F8E4M3: case PrimitiveType::F8E4M3FN: @@ -272,7 +271,6 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::F8E5M2FNUZ: case PrimitiveType::F8E4M3FNUZ: case PrimitiveType::F8E3M4: - case PrimitiveType::F8E8M0FNU: return static_cast(primitive_type); default: DCHECK(false) << "Unsupported primitive type " diff --git a/xla/fp_util_test.cc b/xla/fp_util_test.cc index 8ea22d9d1602b..3eb7c54f919b0 100644 --- a/xla/fp_util_test.cc +++ b/xla/fp_util_test.cc @@ -119,76 +119,6 @@ class FP8E4M3DistanceTest : public ::testing::Test {}; using F8E4M3Types = ::testing::Types; TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types); -TEST(FPDistanceTest, F4E2M1FNDistance) { - // a & b are equal - EXPECT_EQ(CalculateDistanceInFloats( - tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(4.0)), - 0); - - // a & b have the same exponents - EXPECT_EQ(CalculateDistanceInFloats( - tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(6.0)), - 1); - - // a & b have different exponents - EXPECT_EQ(CalculateDistanceInFloats( - tsl::float4_e2m1fn(2.0), tsl::float4_e2m1fn(4.0)), - 2); - - // 1 from 0 in the positive direction - EXPECT_EQ(CalculateDistanceInFloats( - std::numeric_limits::denorm_min(), - tsl::float4_e2m1fn(0)), - 1); - - // 1 from 0 in the negative direction - EXPECT_EQ(CalculateDistanceInFloats( - -std::numeric_limits::denorm_min(), - tsl::float4_e2m1fn(0)), - 1); - - // a & b have different signs - EXPECT_EQ(CalculateDistanceInFloats( - -std::numeric_limits::denorm_min(), - std::numeric_limits::denorm_min()), - 2); - - // 1 non denorm from 0 in the positive direction - EXPECT_EQ(CalculateDistanceInFloats( - std::numeric_limits::min(), - tsl::float4_e2m1fn(0)), - 2); - - // 1 non denorm from 0 in the negative direction - EXPECT_EQ(CalculateDistanceInFloats( - -std::numeric_limits::min(), - tsl::float4_e2m1fn(0)), - 2); - - // a & b have different signs - EXPECT_EQ(CalculateDistanceInFloats( - -std::numeric_limits::min(), - std::numeric_limits::min()), - 4); -} - -TEST(FPDistanceTest, F8E8M0FNUDistance) { - // a & b are equal - EXPECT_EQ(CalculateDistanceInFloats( - tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(1.0)), - 0); - - // one step apart - EXPECT_EQ(CalculateDistanceInFloats( - tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(2.0)), - 1); - - // two steps apart - EXPECT_EQ(CalculateDistanceInFloats( - tsl::float8_e8m0fnu(0.5), tsl::float8_e8m0fnu(2.0)), - 2); -} - TEST(FPDistanceTest, F8E3M4Distance) { // a & b are equal EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), diff --git a/xla/hlo/builder/lib/math.cc b/xla/hlo/builder/lib/math.cc index 620e907f8cf11..f2a77df3d7dda 100644 --- a/xla/hlo/builder/lib/math.cc +++ b/xla/hlo/builder/lib/math.cc @@ -184,7 +184,6 @@ XlaOp IsNegZero(XlaOp operand) { case F32: return Eq(BitcastConvertType(operand, U32), ConstantR0WithType(&b, U32, uint32_t{1} << 31)); - case F4E2M1FN: case F8E3M4: case F8E4M3: case F8E5M2: @@ -972,9 +971,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a)); PrimitiveType a_x_type = a_shape.element_type(); bool needs_upcast = false; - for (PrimitiveType type : - {BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN, - F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) { + for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; @@ -1026,9 +1024,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) { } TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a)); bool needs_upcast = false; - for (PrimitiveType type : - {BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN, - F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) { + for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; diff --git a/xla/hlo/builder/lib/math_test.cc b/xla/hlo/builder/lib/math_test.cc index 126ba14f5bb39..9755643b7586a 100644 --- a/xla/hlo/builder/lib/math_test.cc +++ b/xla/hlo/builder/lib/math_test.cc @@ -95,13 +95,9 @@ class MathTypedTest : public MathTest { Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)}); bool has_inf = std::numeric_limits::has_infinity; - bool has_nan = std::numeric_limits::has_quiet_NaN; - bool has_finite = !has_inf && !has_nan; - bool has_nan_only = !has_inf && has_nan; - auto expected = LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateR1({true, true, true, true, true, has_finite, - has_finite, has_finite, has_finite}), + LiteralUtil::CreateR1( + {true, true, true, true, true, false, false, false, false}), LiteralUtil::CreateR1({false, false, false, false, false, has_inf, has_inf, false, false}), LiteralUtil::CreateR1( @@ -109,8 +105,7 @@ class MathTypedTest : public MathTest { LiteralUtil::CreateR1( {false, false, false, false, false, false, has_inf, false, false}), LiteralUtil::CreateR1({false, false, false, false, false, - has_nan_only, has_nan_only, has_nan, - has_nan})); + !has_inf, !has_inf, true, true})); ComputeAndCompareLiteral(&b, expected, {}); } @@ -123,11 +118,10 @@ class MathTypedTest : public MathTest { LiteralUtil::CreateR1({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}), &b)); - bool is_mx = std::is_same_v; ComputeAndCompareLiteral( &b, LiteralUtil::CreateR1( - {has_negative_zero_v, false, false, false, false, false, is_mx}), + {has_negative_zero_v, false, false, false, false, false, false}), {}, error_spec_); } @@ -142,9 +136,6 @@ class MathTypedTest : public MathTest { // For good measure, we also check pow with an exponent other than 0.5. void TestSqrtPowInequivalence() { SetFastMathDisabled(true); - if (std::is_same_v) { - GTEST_SKIP() << "Skipping due to low precision"; - } // Tests disable constant folding by default, but this test needs it // enabled, otherwise we don't tickle the bug we're trying to catch. @@ -190,14 +181,9 @@ class MathTypedTest : public MathTest { &b); Erf(x); - bool inf_as_nan = !std::numeric_limits::has_infinity && - std::numeric_limits::has_quiet_NaN; - std::vector expected = {inf_as_nan ? nan : T(-1), - inf_as_nan ? nan : T(1), - T(-0), - T(0), - T(-1), - T(1)}; + bool has_inf = std::numeric_limits::has_infinity; + std::vector expected = { + has_inf ? T(-1) : nan, has_inf ? T(1) : nan, T(-0), T(0), T(-1), T(1)}; ComputeAndCompareR1(&b, expected, {}, error_spec_); } @@ -215,10 +201,6 @@ using TestTypes = #endif #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64 double, -#endif -#ifndef XLA_TEST_BACKEND_TPU - // TODO(b/385004399): Run tests on these types on TPU. - tsl::float4_e2m1fn, #endif float>; diff --git a/xla/hlo/evaluator/BUILD b/xla/hlo/evaluator/BUILD index d4b5a37667de0..ef2a6b7fb7603 100644 --- a/xla/hlo/evaluator/BUILD +++ b/xla/hlo/evaluator/BUILD @@ -37,7 +37,6 @@ cc_library( "hlo_evaluator_typed_visitor_int4.cc", "hlo_evaluator_typed_visitor_int64.cc", "hlo_evaluator_typed_visitor_int8.cc", - "hlo_evaluator_typed_visitor_mxfloat.cc", "hlo_evaluator_typed_visitor_uint16.cc", "hlo_evaluator_typed_visitor_uint32.cc", "hlo_evaluator_typed_visitor_uint64.cc", diff --git a/xla/hlo/evaluator/hlo_evaluator.cc b/xla/hlo/evaluator/hlo_evaluator.cc index 8e44243823c09..35fac878f104d 100644 --- a/xla/hlo/evaluator/hlo_evaluator.cc +++ b/xla/hlo/evaluator/hlo_evaluator.cc @@ -3722,7 +3722,7 @@ absl::StatusOr StochasticConvertOp(const Literal& operand_literal, const Shape& result_shape) { std::function stochastic_convert_op = [](Fp operand, Uint random) -> ResultT { - bool is_negative = static_cast(SignAndMagnitude(operand).first); + bool is_negative = static_cast(Eigen::numext::signbit(operand)); if (Eigen::numext::isinf(operand)) { return is_negative ? std::numeric_limits::min() : std::numeric_limits::max(); diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 8499b0ab7107d..74feab55e5e9c 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1736,7 +1736,6 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; -extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; @@ -1744,7 +1743,6 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; -extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc deleted file mode 100644 index 6bc96c1a1f7cd..0000000000000 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/hlo/evaluator/hlo_evaluator.h" -#include "xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" -#include "tsl/platform/ml_dtypes.h" - -namespace xla { -template class HloEvaluatorTypedVisitor; -template class HloEvaluatorTypedVisitor; -} // namespace xla diff --git a/xla/hlo/transforms/expanders/comparison_expander.cc b/xla/hlo/transforms/expanders/comparison_expander.cc index 86d1eeafcd593..0f09ecced1eba 100644 --- a/xla/hlo/transforms/expanders/comparison_expander.cc +++ b/xla/hlo/transforms/expanders/comparison_expander.cc @@ -115,41 +115,34 @@ absl::StatusOr ComparisonExpander::ExpandInstruction( ShapeUtil::ChangeElementType(rhs->shape(), compare_type), rhs)); } - if (compare_type != F8E8M0FNU) { - int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type()); - PrimitiveType signed_type = - primitive_util::SignedIntegralTypeForBitWidth(bit_width); - auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type); - - auto zero_value = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); - zero_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, zero_value, {})); - - auto min_value = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MinValue(signed_type))); - min_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, min_value, {})); - - auto max_value = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); - max_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, max_value, {})); - - lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value, - min_value, max_value); - rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value, - min_value, max_value); - } else { - auto int8_shape = ShapeUtil::ChangeElementType(lhs->shape(), U8); - lhs = computation->AddInstruction( - HloInstruction::CreateBitcastConvert(int8_shape, lhs)); - rhs = computation->AddInstruction( - HloInstruction::CreateBitcastConvert(int8_shape, rhs)); - } + int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type()); + PrimitiveType signed_type = + primitive_util::SignedIntegralTypeForBitWidth(bit_width); + auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type); + + auto zero_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); + zero_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, zero_value, {})); + + auto min_value = computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MinValue(signed_shape.element_type()))); + min_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, min_value, {})); + + auto max_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); + max_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, max_value, {})); + + lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value, + min_value, max_value); + rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value, + min_value, max_value); auto new_compare = computation->AddInstruction(HloInstruction::CreateCompare( - instruction->shape(), lhs, rhs, compare->direction())); + instruction->shape(), lhs, rhs, compare->direction(), + Comparison::Type::kSigned)); VLOG(2) << "New comparison instruction for total order:" << new_compare->ToString(); diff --git a/xla/hlo/transforms/simplifiers/float_normalization.cc b/xla/hlo/transforms/simplifiers/float_normalization.cc index cf978bf581fcd..88dbd2781ca60 100644 --- a/xla/hlo/transforms/simplifiers/float_normalization.cc +++ b/xla/hlo/transforms/simplifiers/float_normalization.cc @@ -217,9 +217,6 @@ absl::Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( hlo->mutable_shape(), [&](Shape* subshape, const xla::ShapeIndex& index) { if (subshape->element_type() == from) { subshape->set_element_type(to); - if (subshape->has_layout() && from == F4E2M1FN) { - subshape->mutable_layout()->set_element_size_in_bits(0); - } } }); float_normalization_->UpdateLayout(hlo->mutable_shape()); diff --git a/xla/hlo/transforms/simplifiers/float_normalization_test.cc b/xla/hlo/transforms/simplifiers/float_normalization_test.cc index b614f74229c0e..86ec889abc652 100644 --- a/xla/hlo/transforms/simplifiers/float_normalization_test.cc +++ b/xla/hlo/transforms/simplifiers/float_normalization_test.cc @@ -150,9 +150,7 @@ class FloatNormalizationF8Test public ::testing::WithParamInterface {}; INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test, - ::testing::Values(F4E2M1FN, F8E3M4, F8E4M3, - F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, - F8E5M2, F8E5M2FNUZ, F8E8M0FNU)); + ::testing::Values(F8E3M4, F8E4M3, F8E5M2)); TEST_F(FloatNormalizationTest, NoopIfSupported) { auto builder = HloComputation::Builder(TestName()); diff --git a/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc b/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc index cea1bc583ea56..f70769ea91abe 100644 --- a/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc +++ b/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include "absl/status/statusor.h" -#include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" @@ -70,25 +69,6 @@ ::mlir::DenseElementsAttr CreateDenseAttrFromLiteral( } return ::mlir::DenseElementsAttr::getFromRawBuffer(type, packed_padded_data); - } else if constexpr (std::is_same_v) { - // DenseElementsAttr::get() does not support being passed an array of - // tsl::float4_e2m1fn. So convert each element to APFloat first. - auto data_span = literal.data(); - std::vector apfloats; - apfloats.reserve(literal.element_count()); - for (size_t i = 0; i < literal.element_count(); i++) { - llvm::APFloat apfloat{static_cast(data_span[i])}; - bool losesInfo; - llvm::APFloat::opStatus status = - apfloat.convert(llvm::APFloat::Float4E2M1FN(), - llvm::APFloat::rmNearestTiesToEven, &losesInfo); - CHECK_EQ(status, llvm::APFloat::opOK) - << "Failed to convert " << data_span[i] << " to Float4E2M1FN APFloat"; - CHECK(!losesInfo) << "Lost info when converting " << data_span[i] - << " to Float4E2M1FN APFloat"; - apfloats.push_back(apfloat); - } - return ::mlir::DenseElementsAttr::get(type, apfloats); } else { auto data_span = literal.data(); return ::mlir::DenseElementsAttr::get( diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo index 577e4ad61f89e..3a1e7ceabb160 100644 --- a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo +++ b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo @@ -421,12 +421,6 @@ add { // CHECK: %[[VAL_13:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> %constant.13 = f8e3m4[4] constant({1, 2, 3, 4}) - - // CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN> - %constant.14 = f4e2m1fn[4] constant({1, 2, 3, 4}) - - // CHECK: %[[VAL_15:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00]> : tensor<4xf8E8M0FNU> - %constant.15 = f8e8m0fnu[4] constant({1, 2, 4, 8}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -548,19 +542,7 @@ add { %convert.15 = f8e3m4[4] convert(f32[4] %convert.14) // CHECK-NEXT: %13 = mhlo.convert %12 : (tensor<4xf8E3M4>) -> tensor<4xf32> - %convert.16 = f32[4] convert(f8e3m4[4] %convert.15) - - // CHECK-NEXT: %14 = mhlo.convert %13 : (tensor<4xf32>) -> tensor<4xf4E2M1FN> - %convert.17 = f4e2m1fn[4] convert(f32[4] %convert.16) - - // CHECK-NEXT: %15 = mhlo.convert %14 : (tensor<4xf4E2M1FN>) -> tensor<4xf32> - %convert.18 = f32[4] convert(f4e2m1fn[4] %convert.17) - - // CHECK-NEXT: %16 = mhlo.convert %15 : (tensor<4xf32>) -> tensor<4xf8E8M0FNU> - %convert.19 = f8e8m0fnu[4] convert(f32[4] %convert.18) - - // CHECK-NEXT: %17 = mhlo.convert %16 : (tensor<4xf8E8M0FNU>) -> tensor<4xf32> - ROOT %convert.20 = f32[4] convert(f8e8m0fnu[4] %convert.19) + ROOT %convert.16 = f32[4] convert(f8e3m4[4] %convert.15) } // CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8> diff --git a/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc b/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc index f50e2a097a327..821f1487cf88c 100644 --- a/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc +++ b/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc @@ -41,12 +41,6 @@ xla::Array ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) { xla::Array array(shape.dimensions()); if constexpr (!xla::primitive_util::IsSubByteNonPredType(type)) { array.SetValues(dense_attr.getValues()); - } else if constexpr (xla::primitive_util::IsMXType(type)) { - // Bitcast MX floating point types from APFloat. - auto values = dense_attr.getValues(); - for (int i = 0; i < values.size(); i++) { - array.data()[i] = T::FromRep(values[i].bitcastToAPInt().getZExtValue()); - } } else { // The only way to get subbyte integers from getValues() is to get them as // APInts. diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir index c017751477cb5..a22ec331d93b2 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir @@ -606,12 +606,6 @@ func.func @main() { // CHECK: f8e3m4[4] constant({1, 2, 3, 4}) %cst_17 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> - // CHECK: f4e2m1fn[4] constant({1, 2, 3, 4}) - %cst_18 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN> - - // CHECK: f8e8m0fnu[4] constant({1, 2, 4, 8}) - %cst_19 = arith.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00]> : tensor<4xf8E8M0FNU> - func.return } @@ -745,11 +739,7 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { %9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32> %10 = "mhlo.convert"(%9) : (tensor<2xf32>) -> tensor<2xf8E3M4> %11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32> - %12 = "mhlo.convert"(%11) : (tensor<2xf32>) -> tensor<2xf4E2M1FN> - %13 = "mhlo.convert"(%12) : (tensor<2xf4E2M1FN>) -> tensor<2xf32> - %14 = "mhlo.convert"(%13) : (tensor<2xf32>) -> tensor<2xf8E8M0FNU> - %15 = "mhlo.convert"(%14) : (tensor<2xf8E8M0FNU>) -> tensor<2xf32> - func.return %15 : tensor<2xf32> + func.return %11 : tensor<2xf32> } // CHECK: ENTRY @@ -765,11 +755,7 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]]) // CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]]) // CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]]) -// CHECK: %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) -// CHECK: %[[E2M1FN_VAL:.*]] = f4e2m1fn[2] convert(f32[2] %[[F32_VAL6]]) -// CHECK: %[[F32_VAL7:.*]] = f32[2] convert(f4e2m1fn[2] %[[E2M1FN_VAL]]) -// CHECK: %[[E8M0FNU_VAL:.*]] = f8e8m0fnu[2] convert(f32[2] %[[F32_VAL7]]) -// CHECK: ROOT %[[F32_VAL8:.*]] = f32[2] convert(f8e8m0fnu[2] %[[E8M0FNU_VAL]]) +// CHECK: ROOT %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) // ----- diff --git a/xla/literal.cc b/xla/literal.cc index 866bc1838a919..997f44a4dd0f6 100644 --- a/xla/literal.cc +++ b/xla/literal.cc @@ -91,11 +91,10 @@ bool LiteralProtoHasValues(const LiteralProto& proto) { !proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() || !proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() || !proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() || - !proto.f4e2m1fns().empty() || !proto.f8e3m4s().empty() || - !proto.f8e4m3b11fnuzs().empty() || !proto.f8e4m3fns().empty() || - !proto.f8e4m3fnuzs().empty() || !proto.f8e4m3s().empty() || - !proto.f8e5m2fnuzs().empty() || !proto.f8e5m2s().empty() || - !proto.f8e8m0fnus().empty() || !proto.f16s().empty() || + !proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() || + !proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() || + !proto.f8e5m2fnuzs().empty() || !proto.f8e4m3fnuzs().empty() || + !proto.f8e3m4s().empty() || !proto.f16s().empty() || !proto.bf16s().empty() || proto.f32s_size() || proto.f64s_size() || proto.c64s_size() || proto.c128s_size() || proto.preds_size() || proto.tuple_literals_size(); @@ -1875,6 +1874,7 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { << __func__ << " is only supported for dense arrays: " << subshape(); CHECK_EQ(size_bytes_dense(), other.size_bytes_dense()); if (primitive_util::IsSubByteNonPredType(subshape().element_type())) { + CHECK(!primitive_util::IsFloatingPointType(subshape().element_type())); auto one_array = buffer(); auto two_array = other.buffer(); const int bits_per_element = @@ -2259,11 +2259,6 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case S64: CopyToRepeatedField(proto->mutable_s64s(), data()); break; - case F4E2M1FN: - *proto->mutable_f4e2m1fns() = std::string( - reinterpret_cast(data().data()), - size_bytes_dense()); - break; case F8E5M2: *proto->mutable_f8e5m2s() = std::string( reinterpret_cast(data().data()), @@ -2299,11 +2294,6 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; - case F8E8M0FNU: - *proto->mutable_f8e8m0fnus() = std::string( - reinterpret_cast(data().data()), - size_bytes_dense()); - break; case F16: *proto->mutable_f16s() = std::string(reinterpret_cast(data().data()), @@ -2455,14 +2445,6 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { case U64: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); break; - case F4E2M1FN: { - const std::string& s(proto.f4e2m1fns()); - TF_RET_CHECK(data().size() * - sizeof(tsl::float4_e2m1fn) == - s.size()); - memcpy(untyped_data(), s.data(), s.size()); - break; - } case F8E5M2: { const std::string& s(proto.f8e5m2s()); TF_RET_CHECK(data().size() * sizeof(tsl::float8_e5m2) == @@ -2516,14 +2498,6 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { memcpy(untyped_data(), s.data(), s.size()); break; } - case F8E8M0FNU: { - const std::string& s(proto.f8e8m0fnus()); - TF_RET_CHECK(data().size() * - sizeof(tsl::float8_e8m0fnu) == - s.size()); - memcpy(untyped_data(), s.data(), s.size()); - break; - } case F16: { const std::string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/xla/literal.h b/xla/literal.h index db40cd7f65003..0c028bd1aa60e 100644 --- a/xla/literal.h +++ b/xla/literal.h @@ -589,17 +589,18 @@ class LiteralBase { primitive_util::NativeToPrimitiveType(); constexpr int bits_per_element = primitive_util::BitWidth(primitive_type); if constexpr (bits_per_element < 8) { + static_assert(!primitive_util::IsFloatingPointType(primitive_type)); static_assert(!primitive_util::IsComplexType(primitive_type)); static_assert(8 % bits_per_element == 0); - constexpr int elements_per_byte = 8 / bits_per_element; + int64_t bytes = elements.size() / elements_per_byte; for (int64_t i = 0; i < bytes; ++i) { uint8_t byte = 0; for (int b = 0; b < elements_per_byte; ++b) { - uint8_t src = Eigen::numext::bit_cast( - elements[i * elements_per_byte + b]) & - LsbMask(bits_per_element); + uint8_t src = + static_cast(elements[i * elements_per_byte + b]) & + LsbMask(bits_per_element); byte |= src << (b * bits_per_element); } WriteElement(byte); @@ -608,9 +609,9 @@ class LiteralBase { if (rest != 0) { uint8_t byte = 0; for (int64_t b = 0; b < rest; ++b) { - uint8_t src = Eigen::numext::bit_cast( - elements[bytes * elements_per_byte + b]) & - LsbMask(bits_per_element); + uint8_t src = + static_cast(elements[bytes * elements_per_byte + b]) & + LsbMask(bits_per_element); byte |= src << (b * bits_per_element); } WriteElement(byte); @@ -700,17 +701,11 @@ class LiteralBase { primitive_util::NativeToPrimitiveType(); constexpr int bits_per_element = primitive_util::BitWidth(primitive_type); if constexpr (bits_per_element < 8) { + static_assert(!primitive_util::IsFloatingPointType(primitive_type)); static_assert(!primitive_util::IsComplexType(primitive_type)); static_assert(8 % bits_per_element == 0); - - constexpr auto cast = [](uint8_t x) -> NativeT { - if constexpr (primitive_util::IsFloatingPointType(primitive_type)) { - return Eigen::numext::bit_cast(x); - } - return static_cast(x); - }; - constexpr int elements_per_byte = 8 / bits_per_element; + int64_t bytes = elements.size() / elements_per_byte; for (int64_t i = 0; i < bytes; ++i) { uint8_t byte; @@ -719,7 +714,7 @@ class LiteralBase { } for (int b = 0; b < elements_per_byte; ++b) { elements[i * elements_per_byte + b] = - cast(byte & LsbMask(bits_per_element)); + static_cast(byte & LsbMask(bits_per_element)); byte >>= bits_per_element; } } @@ -731,7 +726,7 @@ class LiteralBase { } for (int64_t b = 0; b < rest; ++b) { elements[bytes * elements_per_byte + b] = - cast(byte & LsbMask(bits_per_element)); + static_cast(byte & LsbMask(bits_per_element)); byte >>= bits_per_element; } } diff --git a/xla/literal_comparison.cc b/xla/literal_comparison.cc index ecea502496393..c97629594122b 100644 --- a/xla/literal_comparison.cc +++ b/xla/literal_comparison.cc @@ -206,8 +206,8 @@ template std::string FpValueToString(NativeT value) { if constexpr (is_specialized_floating_point_v) { constexpr int kPrecisionDigits = std::numeric_limits::max_digits10; - const int kExponentDigts = std::ceil( - std::log10(std::max(std::numeric_limits::max_exponent10, 1))); + const int kExponentDigts = + std::ceil(std::log10(std::numeric_limits::max_exponent10)); constexpr int kExtraChars = 4; const int kTotalChars = kPrecisionDigits * kExponentDigts + kExtraChars; return absl::StrFormat("%*.*g", kTotalChars, kPrecisionDigits, @@ -418,9 +418,6 @@ class NearComparator { } else { float_distance = CalculateFloatDistance(expected, actual); abs_error = FpAbsoluteValue(actual - expected); - if (!std::numeric_limits::is_signed && IsNaN(abs_error)) { - abs_error = FpAbsoluteValue(expected - actual); - } // Avoid division by 0 even though it's well-defined because ubsan can be // configured to treat this as a fatal error. diff --git a/xla/literal_comparison_test.cc b/xla/literal_comparison_test.cc index 29c12eb7c75e4..7713aceaaa3bc 100644 --- a/xla/literal_comparison_test.cc +++ b/xla/literal_comparison_test.cc @@ -30,15 +30,13 @@ template class LiteralComparisonTest : public ::testing::Test {}; using TestedTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes); TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { - auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); - auto expected = LiteralUtil::CreateR0(TypeParam(1.0)); + auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); + auto expected = LiteralUtil::CreateR0(TypeParam(8.0)); TF_EXPECT_OK(literal_comparison::Near(expected, actual, ErrorSpec(0.0, 0.0), /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); @@ -46,16 +44,12 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); - auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); - float expV = 1.125; // F8E4M3* - if (type == F8E5M2 || type == F8E5M2FNUZ) - expV = 1.25; + auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); + float expV = 9.0; // F8E4M3* + if (type == F8E5M2) + expV = 10.0; else if (type == F8E3M4) - expV = 1.0625; - else if (type == F4E2M1FN) - expV = 1.5; - else if (type == F8E8M0FNU) - expV = 2.0; + expV = 8.5; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, @@ -70,16 +64,12 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); - auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); - float expV = 1.5; // F8E4M3* - if (type == F8E5M2 || type == F8E5M2FNUZ) - expV = 2.0; + auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); + float expV = 12.0; // F8E4M3* + if (type == F8E5M2) + expV = 14.0; else if (type == F8E3M4) - expV = 1.25; - else if (type == F4E2M1FN) - expV = 4.0; - else if (type == F8E8M0FNU) - expV = 16.0; + expV = 10.0; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; @@ -96,16 +86,12 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); - auto actual = LiteralUtil::CreateR0(1.0); - float expV = 1.51; // F8E4M3* - if (type == F8E5M2 || type == F8E5M2FNUZ) - expV = 2.01; + auto actual = LiteralUtil::CreateR0(8.0); + float expV = 12.1; // F8E4M3* + if (type == F8E5M2) + expV = 13.0; else if (type == F8E3M4) - expV = 1.26; - else if (type == F4E2M1FN) - expV = 4.1; - else if (type == F8E8M0FNU) - expV = 16.5; + expV = 10.125; auto expected = LiteralUtil::CreateR0(expV); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; diff --git a/xla/literal_test.cc b/xla/literal_test.cc index 7aa9f2dc040dc..44e4acd6a5cef 100644 --- a/xla/literal_test.cc +++ b/xla/literal_test.cc @@ -124,11 +124,11 @@ class LiteralUtilTest : public ::testing::Test { template class LiteralUtilFloatTest : public LiteralUtilTest {}; -using FloatTypes = ::testing::Types; +using FloatTypes = + ::testing::Types; TYPED_TEST_SUITE(LiteralUtilFloatTest, FloatTypes); @@ -175,10 +175,6 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { LiteralUtil::CreateR0(static_cast(9.001f)); EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString()); - auto f4e2m1fn_lit = - LiteralUtil::CreateR0(tsl::float4_e2m1fn(0.5)); - EXPECT_EQ("f4e2m1fn[] 0.5", f4e2m1fn_lit.ToString()); - auto f8e5m2_lit = LiteralUtil::CreateR0(tsl::float8_e5m2(0.5)); EXPECT_EQ("f8e5m2[] 0.5", f8e5m2_lit.ToString()); @@ -211,10 +207,6 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f8e3m4_lit = LiteralUtil::CreateR0(tsl::float8_e3m4(0.5)); EXPECT_EQ("f8e3m4[] 0.5", f8e3m4_lit.ToString()); - - auto f8e8m0fnu_lit = - LiteralUtil::CreateR0(tsl::float8_e8m0fnu(0.5)); - EXPECT_EQ("f8e8m0fnu[] 0.5", f8e8m0fnu_lit.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -667,11 +659,6 @@ TEST_F(LiteralUtilTest, IsAll) { bfloat16 b90(9.00f); EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); - tsl::float4_e2m1fn m16(4); - EXPECT_TRUE(LiteralUtil::CreateR1({m16}).IsAll(4)); - // 5 rounds to 4 in E2M1FN but is not equal to 4, so this should be false - EXPECT_FALSE(LiteralUtil::CreateR1({m16}).IsAll(5)); - tsl::float8_e5m2 p16(8); EXPECT_TRUE(LiteralUtil::CreateR1({p16}).IsAll(8)); // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false @@ -702,11 +689,6 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(LiteralUtil::CreateR1({v16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({v16}).IsAll(9)); - tsl::float8_e8m0fnu w16(8); - EXPECT_TRUE(LiteralUtil::CreateR1({w16}).IsAll(8)); - // 9 rounds to 8 in E8M0FNU but is not equal to 8, so this should be false - EXPECT_FALSE(LiteralUtil::CreateR1({w16}).IsAll(9)); - complex64 c8_9 = {8, 9}; EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); @@ -2232,9 +2214,6 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); auto vector_half = LiteralUtil::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); - using e2m1 = tsl::float4_e2m1fn; - auto vector_f4e2m1fn = - LiteralUtil::CreateR1({e2m1{1.0}, e2m1{2.0}, e2m1{-3.0}}); using e5 = tsl::float8_e5m2; auto vector_f8e5m2 = LiteralUtil::CreateR1({e5{10.0}, e5{20.0}, e5{-32.0}}); @@ -2255,9 +2234,6 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { LiteralUtil::CreateR1({e4f{10.0}, e4f{20.0}, e4f{-30.0}}); using e3 = tsl::float8_e3m4; auto vector_f8e3m4 = LiteralUtil::CreateR1({e3{2.5}, e3{5.0}, e3{-8.0}}); - using e8m0 = tsl::float8_e8m0fnu; - auto vector_f8e8m0fnu = - LiteralUtil::CreateR1({e8m0{1.0}, e8m0{2.0}, e8m0{4.0}}); auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto vector_s4 = LiteralUtil::CreateR1({s4{-1}, s4{3}, s4{7}}); @@ -2278,15 +2254,13 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); EXPECT_EQ(vector_c128, to_from_proto(vector_c128)); EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); - EXPECT_EQ(vector_f4e2m1fn, to_from_proto(vector_f4e2m1fn)); - EXPECT_EQ(vector_f8e3m4, to_from_proto(vector_f8e3m4)); + EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); EXPECT_EQ(vector_f8e4m3, to_from_proto(vector_f8e4m3)); - EXPECT_EQ(vector_f8e4m3b11, to_from_proto(vector_f8e4m3b11)); EXPECT_EQ(vector_f8e4m3fn, to_from_proto(vector_f8e4m3fn)); - EXPECT_EQ(vector_f8e4m3fnuz, to_from_proto(vector_f8e4m3fnuz)); - EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); + EXPECT_EQ(vector_f8e4m3b11, to_from_proto(vector_f8e4m3b11)); EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); - EXPECT_EQ(vector_f8e8m0fnu, to_from_proto(vector_f8e8m0fnu)); + EXPECT_EQ(vector_f8e4m3fnuz, to_from_proto(vector_f8e4m3fnuz)); + EXPECT_EQ(vector_f8e3m4, to_from_proto(vector_f8e3m4)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(vector_s4, to_from_proto(vector_s4)); EXPECT_EQ(vector_u4, to_from_proto(vector_u4)); @@ -2537,19 +2511,19 @@ TEST_F(LiteralUtilTest, SliceOnBool) { } TEST_F(LiteralUtilTest, IsEqualAt) { - double val_double = 4.0; - int val_integral = 4; - Literal c1 = LiteralUtil::CreateR0(val_integral); + double val_double = 10.0; + int val_integral = 10; + Literal c1 = LiteralUtil::CreateR0(10); EXPECT_TRUE(c1.IsEqualAt({}, val_double)); EXPECT_TRUE(c1.IsEqualAt({}, val_integral)); - Literal c2 = LiteralUtil::CreateR0(val_double); + Literal c2 = LiteralUtil::CreateR0(10); EXPECT_TRUE(c2.IsEqualAt({}, val_double)); EXPECT_TRUE(c2.IsEqualAt({}, val_integral)); Literal c3 = LiteralUtil::CreateR0(tsl::float8_e5m2{val_double}); EXPECT_TRUE(c3.IsEqualAt({}, val_double)); EXPECT_TRUE(c3.IsEqualAt({}, val_integral)); - complex128 val_complex = {val_double, 0}; + complex128 val_complex = {10, 0}; EXPECT_TRUE(c1.IsEqualAt({}, val_complex)); EXPECT_TRUE(c2.IsEqualAt({}, val_complex)); EXPECT_TRUE(c3.IsEqualAt({}, val_complex)); @@ -2558,8 +2532,8 @@ TEST_F(LiteralUtilTest, IsEqualAt) { EXPECT_TRUE(c4.IsEqualAt({}, val_integral)); EXPECT_TRUE(c4.IsEqualAt({}, val_complex)); EXPECT_FALSE(c4.IsEqualAt({}, std::numeric_limits::infinity())); - complex128 val_true_complex = {val_double, 3}; - complex64 val_smaller_complex = {static_cast(val_double), 3}; + complex128 val_true_complex = {10, 3}; + complex64 val_smaller_complex = {10, 3}; Literal c5 = LiteralUtil::CreateR0(val_true_complex); EXPECT_TRUE(c5.IsEqualAt({}, val_true_complex)); EXPECT_TRUE(c5.IsEqualAt({}, val_smaller_complex)); @@ -2583,14 +2557,6 @@ TEST_F(LiteralUtilTest, IsEqualAt) { LiteralUtil::CreateR0(tsl::float8_e3m4{val_double}); EXPECT_TRUE(c10.IsEqualAt({}, val_double)); EXPECT_TRUE(c10.IsEqualAt({}, val_integral)); - Literal c11 = - LiteralUtil::CreateR0(tsl::float4_e2m1fn{val_double}); - EXPECT_TRUE(c11.IsEqualAt({}, val_double)); - EXPECT_TRUE(c11.IsEqualAt({}, val_integral)); - Literal c12 = LiteralUtil::CreateR0( - tsl::float8_e8m0fnu{val_double}); - EXPECT_TRUE(c12.IsEqualAt({}, val_double)); - EXPECT_TRUE(c12.IsEqualAt({}, val_integral)); } TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { @@ -2916,11 +2882,10 @@ class LiteralSerializationTest : public ::testing::Test, static std::vector GenerateSimpleParams() { std::vector params; for (PrimitiveType element_type : - {PRED, S4, U4, S8, U8, S16, - U16, S32, U32, S64, U64, F16, - F32, F64, BF16, F4E2M1FN, F8E3M4, F8E4M3, - F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F8E8M0FNU, - C64, C128}) { + {PRED, S4, U4, S8, U8, S16, + U16, S32, U32, S64, U64, F16, + F32, F64, BF16, F8E5M2, F8E4M3, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ, F8E3M4, C64, C128}) { for (const DimensionVector& dimensions : { DimensionVector{}, DimensionVector{0}, diff --git a/xla/mlir/utils/type_util.cc b/xla/mlir/utils/type_util.cc index ea8da4d4990d9..2581390a1e13d 100644 --- a/xla/mlir/utils/type_util.cc +++ b/xla/mlir/utils/type_util.cc @@ -32,8 +32,6 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( switch (type) { case xla::PrimitiveType::PRED: return b.getI1Type(); - case xla::PrimitiveType::F4E2M1FN: - return b.getFloat4E2M1FNType(); case xla::PrimitiveType::F8E5M2: return b.getFloat8E5M2Type(); case xla::PrimitiveType::F8E4M3: @@ -48,8 +46,6 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( return b.getFloat8E4M3FNUZType(); case xla::PrimitiveType::F8E3M4: return b.getFloat8E3M4Type(); - case xla::PrimitiveType::F8E8M0FNU: - return b.getFloat8E8M0FNUType(); case xla::PrimitiveType::F16: return b.getF16Type(); case xla::PrimitiveType::BF16: @@ -82,9 +78,7 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( } xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { - if (type.isFloat4E2M1FN()) { - return xla::PrimitiveType::F4E2M1FN; - } else if (type.isFloat8E5M2()) { + if (type.isFloat8E5M2()) { return xla::PrimitiveType::F8E5M2; } else if (type.isFloat8E4M3()) { return xla::PrimitiveType::F8E4M3; @@ -98,8 +92,6 @@ xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { return xla::PrimitiveType::F8E5M2FNUZ; } else if (type.isFloat8E3M4()) { return xla::PrimitiveType::F8E3M4; - } else if (type.isFloat8E8M0FNU()) { - return xla::PrimitiveType::F8E8M0FNU; } else if (type.isBF16()) { return xla::PrimitiveType::BF16; } else if (type.isF16()) { diff --git a/xla/mlir/utils/type_util_test.cc b/xla/mlir/utils/type_util_test.cc index 2239943d906b7..a8043ab0b5f14 100644 --- a/xla/mlir/utils/type_util_test.cc +++ b/xla/mlir/utils/type_util_test.cc @@ -101,7 +101,6 @@ INSTANTIATE_TEST_SUITE_P( Execute, TypeUtilTest, ::testing::ValuesIn(std::vector( {{PRED, [](mlir::Builder b) { return b.getI1Type(); }}, - {F4E2M1FN, [](mlir::Builder b) { return b.getFloat4E2M1FNType(); }}, {F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }}, {F8E4M3, [](mlir::Builder b) { return b.getFloat8E4M3Type(); }}, {F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }}, @@ -112,7 +111,6 @@ INSTANTIATE_TEST_SUITE_P( {F8E4M3FNUZ, [](mlir::Builder b) { return b.getFloat8E4M3FNUZType(); }}, {F8E3M4, [](mlir::Builder b) { return b.getFloat8E3M4Type(); }}, - {F8E8M0FNU, [](mlir::Builder b) { return b.getFloat8E8M0FNUType(); }}, {F16, [](mlir::Builder b) { return b.getF16Type(); }}, {BF16, [](mlir::Builder b) { return b.getBF16Type(); }}, {F32, [](mlir::Builder b) { return b.getF32Type(); }}, diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index 16a64cdc22b76..d07a178c6c4e7 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -6844,13 +6844,6 @@ func.func @invalid_dimension_attr(%arg0: tensor) -> tensor { - %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - func.func @f8e3m4(%arg0: tensor) -> tensor { %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor func.return %0 : tensor @@ -6879,13 +6872,6 @@ func.func @f8e5m2(%arg0: tensor) -> tensor { // ----- -func.func @f8e8m0fnu(%arg0: tensor) -> tensor { - %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - func.func @top_k_1d(%arg0 : tensor<16xf32>) { %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) return diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index fe9158f41e337..5852c9a54dcc0 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,7 +1,4 @@ # PJRT C API changelog -## 0.61 -* Added types F4E2M1FN and F8E8M0FNU. - ## 0.60 * Added ``PJRT_Client_CreateBuffersForAsyncHostToDevice`` and ``PJRT_AsyncHostToDeviceTransferManager_TransferRawDataToSubBuffer``. diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index 61a1f8785bc58..36d82b0787ba4 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 61 +#define PJRT_API_MINOR 60 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -681,10 +681,6 @@ typedef enum { // More truncated 8 bit floating-point formats. PJRT_Buffer_Type_F8E4M3, PJRT_Buffer_Type_F8E3M4, - PJRT_Buffer_Type_F8E8M0FNU, - - // 4-bit MX floating-point format. - PJRT_Buffer_Type_F4E2M1FN, } PJRT_Buffer_Type; typedef enum { diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index b1ad44329a40e..2060a73a634a4 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -310,8 +310,6 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_BF16; case xla::PrimitiveType::F64: return PJRT_Buffer_Type::PJRT_Buffer_Type_F64; - case xla::PrimitiveType::F4E2M1FN: - return PJRT_Buffer_Type::PJRT_Buffer_Type_F4E2M1FN; case xla::PrimitiveType::F8E5M2: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2; case xla::PrimitiveType::F8E4M3: @@ -326,8 +324,6 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ; case xla::PrimitiveType::F8E3M4: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4; - case xla::PrimitiveType::F8E8M0FNU: - return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E8M0FNU; case xla::PrimitiveType::C64: return PJRT_Buffer_Type::PJRT_Buffer_Type_C64; case xla::PrimitiveType::C128: @@ -381,8 +377,6 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::C64; case PJRT_Buffer_Type::PJRT_Buffer_Type_C128: return xla::PrimitiveType::C128; - case PJRT_Buffer_Type::PJRT_Buffer_Type_F4E2M1FN: - return xla::PrimitiveType::F4E2M1FN; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2: return xla::PrimitiveType::F8E5M2; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3: @@ -397,8 +391,6 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::F8E4M3FNUZ; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4: return xla::PrimitiveType::F8E3M4; - case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E8M0FNU: - return xla::PrimitiveType::F8E8M0FNU; case PJRT_Buffer_Type::PJRT_Buffer_Type_INVALID: CHECK(false) << "Buffer type is not supported in C API layer."; } diff --git a/xla/primitive_util.cc b/xla/primitive_util.cc index 5006406ea9977..b70ba275a1f47 100644 --- a/xla/primitive_util.cc +++ b/xla/primitive_util.cc @@ -93,18 +93,6 @@ bool HasInfinity(PrimitiveType type) { return false; } -bool HasNaN(PrimitiveType type) { - if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { - return FloatingPointTypeSwitch( - [&](auto constant_type) -> bool { - return std::numeric_limits< - NativeTypeOf>::has_quiet_NaN; - }, - type); - } - return false; -} - bool HasNegativeZero(PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { return FloatingPointTypeSwitch( diff --git a/xla/primitive_util.h b/xla/primitive_util.h index 70a8335c8bc51..b9c1c978bc620 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h @@ -69,9 +69,6 @@ int ExponentBias(PrimitiveType type); // Returns whether the type has a value for infinity. bool HasInfinity(PrimitiveType type); -// Returns whether the type has a value for NaN. -bool HasNaN(PrimitiveType type); - // Returns whether the type has a value for negative zero. bool HasNegativeZero(PrimitiveType type); @@ -188,11 +185,6 @@ constexpr PrimitiveType NativeToPrimitiveType() { return BF16; } -template <> -constexpr PrimitiveType NativeToPrimitiveType() { - return F4E2M1FN; -} - template <> constexpr PrimitiveType NativeToPrimitiveType() { return F8E5M2; @@ -228,11 +220,6 @@ constexpr PrimitiveType NativeToPrimitiveType() { return F8E3M4; } -template <> -constexpr PrimitiveType NativeToPrimitiveType() { - return F8E8M0FNU; -} - // Complex template <> constexpr PrimitiveType NativeToPrimitiveType() { @@ -347,11 +334,6 @@ struct PrimitiveTypeToNative { using type = bfloat16; }; -template <> -struct PrimitiveTypeToNative { - using type = tsl::float4_e2m1fn; -}; - template <> struct PrimitiveTypeToNative { using type = tsl::float8_e5m2; @@ -387,11 +369,6 @@ struct PrimitiveTypeToNative { using type = tsl::float8_e3m4; }; -template <> -struct PrimitiveTypeToNative { - using type = tsl::float8_e8m0fnu; -}; - // Complex template <> struct PrimitiveTypeToNative { @@ -424,10 +401,6 @@ inline constexpr bool IsArrayType(PrimitiveType primitive_type) { primitive_type < PrimitiveType_ARRAYSIZE; } -constexpr bool IsMXType(PrimitiveType type) { - return type == F4E2M1FN || type == F8E8M0FNU; -} - constexpr bool IsF8Type(PrimitiveType type) { return type == F8E5M2 || type == F8E4M3 || type == F8E4M3FN || type == F8E4M3B11FNUZ || type == F8E5M2FNUZ || type == F8E4M3FNUZ || @@ -436,7 +409,7 @@ constexpr bool IsF8Type(PrimitiveType type) { constexpr bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64 || type == BF16 || - IsF8Type(type) || IsMXType(type); + IsF8Type(type); } constexpr bool IsComplexType(PrimitiveType type) { @@ -500,9 +473,6 @@ template constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { switch (type) { - case F4E2M1FN: - return std::forward(f)( - PrimitiveTypeConstant()); case F8E3M4: return std::forward(f)( PrimitiveTypeConstant()); @@ -524,9 +494,6 @@ constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { case F8E5M2FNUZ: return std::forward(f)( PrimitiveTypeConstant()); - case F8E8M0FNU: - return std::forward(f)( - PrimitiveTypeConstant()); case F16: return std::forward(f)(PrimitiveTypeConstant()); case BF16: @@ -610,9 +577,6 @@ inline constexpr int PrimitiveTypeBitWidth() { if constexpr (primitive_type == PRED) { return std::numeric_limits::digits; } - if constexpr (IsMXType(primitive_type)) { - return NativeT::kBits; - } if constexpr (IsFloatingPointType(primitive_type)) { return sizeof(NativeT) * std::numeric_limits::digits; } @@ -751,10 +715,6 @@ inline bool CastPreservesValues(PrimitiveType from_type, if (from_type == to_type) { return true; } - // * -> F8E8M0FNU is not possible because zero cannot be represented. - if (to_type == F8E8M0FNU) { - return false; - } // PRED -> * if (from_type == PRED) { return true; @@ -777,33 +737,21 @@ inline bool CastPreservesValues(PrimitiveType from_type, return false; } // F -> F is safe if the exponent/significand are preserved and `to_type` - // preserves infinities/nans/unsigned zero in `from_type`. + // preserves infinities in `from_type. if (primitive_util::IsFloatingPointType(from_type) && primitive_util::IsFloatingPointType(to_type)) { - return - // Target mantissa should be large enough. - primitive_util::SignificandWidth(from_type) <= - primitive_util::SignificandWidth(to_type) && - // Target exponent should be large enough. - primitive_util::ExponentWidth(from_type) <= - primitive_util::ExponentWidth(to_type) && - // HasInfinity check. - (!primitive_util::HasInfinity(from_type) || - primitive_util::HasInfinity(to_type)) && - // HasNaN check. - (!primitive_util::HasNaN(from_type) || - primitive_util::HasNaN(to_type)) && - // HasNegativeZero check. - (!primitive_util::HasNegativeZero(from_type) || - primitive_util::HasNegativeZero(to_type)) && - // Minimum denormal should be representable by target type. - (primitive_util::UnderflowExponent(from_type) - - primitive_util::SignificandWidth(from_type)) >= - (primitive_util::UnderflowExponent(to_type) - - primitive_util::SignificandWidth(to_type)) && - // Maximum exponent may be larger with custom bias (e.g. F8E4M3B11FNUZ). - primitive_util::OverflowExponent(from_type) <= - primitive_util::OverflowExponent(to_type); + return (!primitive_util::HasInfinity(from_type) || + primitive_util::HasInfinity(to_type)) && + primitive_util::SignificandWidth(from_type) <= + primitive_util::SignificandWidth(to_type) && + primitive_util::ExponentWidth(from_type) <= + primitive_util::ExponentWidth(to_type) && + (primitive_util::UnderflowExponent(from_type) - + primitive_util::SignificandWidth(from_type)) >= + (primitive_util::UnderflowExponent(to_type) - + primitive_util::SignificandWidth(to_type)) && + primitive_util::OverflowExponent(from_type) <= + primitive_util::OverflowExponent(to_type); } // F -> I is not safe because it drops fractional numbers. if (!primitive_util::IsIntegralType(from_type)) { diff --git a/xla/primitive_util_test.cc b/xla/primitive_util_test.cc index 68fad70096812..190e6442d0326 100644 --- a/xla/primitive_util_test.cc +++ b/xla/primitive_util_test.cc @@ -69,9 +69,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[PRED][F8E4M3] = expecteds[PRED][F8E4M3FN] = true; expecteds[PRED][F8E4M3B11FNUZ] = expecteds[PRED][F8E5M2FNUZ] = true; expecteds[PRED][F8E4M3FNUZ] = expecteds[PRED][F8E3M4] = true; - expecteds[PRED][F4E2M1FN] = true; - expecteds[PRED][F8E8M0FNU] = false; expecteds[S1][PRED] = false; + expecteds[S2][PRED] = false; expecteds[S1][S1] = true; expecteds[S1][S2] = true; expecteds[S1][S4] = true; @@ -92,7 +91,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S1][C64] = true; expecteds[S1][BF16] = true; expecteds[S1][C128] = true; - expecteds[S1][F4E2M1FN] = true; expecteds[S1][F8E5M2] = true; expecteds[S1][F8E4M3] = true; expecteds[S1][F8E4M3FN] = true; @@ -100,11 +98,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S1][F8E5M2FNUZ] = true; expecteds[S1][F8E4M3FNUZ] = true; expecteds[S1][F8E3M4] = true; - expecteds[S1][F8E8M0FNU] = false; - expecteds[S2][PRED] = false; expecteds[S2][S1] = false; - expecteds[S2][S2] = true; - expecteds[S2][S4] = true; + expecteds[S2][S2] = expecteds[S2][S4] = true; expecteds[S2][S8] = true; expecteds[S2][S16] = true; expecteds[S2][S32] = true; @@ -122,7 +117,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][C64] = true; expecteds[S2][BF16] = true; expecteds[S2][C128] = true; - expecteds[S2][F4E2M1FN] = true; expecteds[S2][F8E5M2] = true; expecteds[S2][F8E4M3] = true; expecteds[S2][F8E4M3FN] = true; @@ -130,7 +124,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][F8E5M2FNUZ] = true; expecteds[S2][F8E4M3FNUZ] = true; expecteds[S2][F8E3M4] = true; - expecteds[S2][F8E8M0FNU] = false; expecteds[S4][PRED] = false; expecteds[S4][S1] = false; expecteds[S4][S2] = false; @@ -152,7 +145,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][C64] = true; expecteds[S4][BF16] = true; expecteds[S4][C128] = true; - expecteds[S4][F4E2M1FN] = false; expecteds[S4][F8E5M2] = true; expecteds[S4][F8E4M3] = true; expecteds[S4][F8E4M3FN] = true; @@ -160,7 +152,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][F8E5M2FNUZ] = true; expecteds[S4][F8E4M3FNUZ] = true; expecteds[S4][F8E3M4] = true; - expecteds[S4][F8E8M0FNU] = false; expecteds[S8][PRED] = false; expecteds[S8][S1] = false; expecteds[S8][S2] = false; @@ -182,7 +173,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][C64] = true; expecteds[S8][BF16] = true; expecteds[S8][C128] = true; - expecteds[S8][F4E2M1FN] = false; expecteds[S8][F8E5M2] = false; expecteds[S8][F8E4M3] = false; expecteds[S8][F8E4M3FN] = false; @@ -190,7 +180,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][F8E5M2FNUZ] = false; expecteds[S8][F8E4M3FNUZ] = false; expecteds[S8][F8E3M4] = false; - expecteds[S8][F8E8M0FNU] = false; expecteds[S16][PRED] = false; expecteds[S16][S1] = false; expecteds[S16][S2] = false; @@ -212,7 +201,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][C64] = true; expecteds[S16][BF16] = false; expecteds[S16][C128] = true; - expecteds[S16][F4E2M1FN] = false; expecteds[S16][F8E5M2] = false; expecteds[S16][F8E4M3] = false; expecteds[S16][F8E4M3FN] = false; @@ -220,7 +208,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][F8E5M2FNUZ] = false; expecteds[S16][F8E4M3FNUZ] = false; expecteds[S16][F8E3M4] = false; - expecteds[S16][F8E8M0FNU] = false; expecteds[S32][PRED] = false; expecteds[S32][S1] = false; expecteds[S32][S2] = false; @@ -242,7 +229,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][C64] = false; expecteds[S32][BF16] = false; expecteds[S32][C128] = true; - expecteds[S32][F4E2M1FN] = false; expecteds[S32][F8E5M2] = false; expecteds[S32][F8E4M3] = false; expecteds[S32][F8E4M3FN] = false; @@ -250,7 +236,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][F8E5M2FNUZ] = false; expecteds[S32][F8E4M3FNUZ] = false; expecteds[S32][F8E3M4] = false; - expecteds[S32][F8E8M0FNU] = false; expecteds[S64][PRED] = false; expecteds[S64][S1] = false; expecteds[S64][S2] = false; @@ -272,7 +257,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][C64] = false; expecteds[S64][BF16] = false; expecteds[S64][C128] = false; - expecteds[S64][F4E2M1FN] = false; expecteds[S64][F8E5M2] = false; expecteds[S64][F8E4M3] = false; expecteds[S64][F8E4M3FN] = false; @@ -280,7 +264,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][F8E5M2FNUZ] = false; expecteds[S64][F8E4M3FNUZ] = false; expecteds[S64][F8E3M4] = false; - expecteds[S64][F8E8M0FNU] = false; expecteds[U1][PRED] = false; expecteds[U1][S1] = false; expecteds[U1][S2] = true; @@ -302,7 +285,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U1][C64] = true; expecteds[U1][BF16] = true; expecteds[U1][C128] = true; - expecteds[U1][F4E2M1FN] = true; + expecteds[U1][BF16] = true; + expecteds[U1][C128] = true; expecteds[U1][F8E5M2] = true; expecteds[U1][F8E4M3] = true; expecteds[U1][F8E4M3FN] = true; @@ -310,16 +294,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U1][F8E5M2FNUZ] = true; expecteds[U1][F8E4M3FNUZ] = true; expecteds[U1][F8E3M4] = true; - expecteds[U1][F8E8M0FNU] = false; expecteds[U2][PRED] = false; - expecteds[U2][S1] = false; + expecteds[U2][U1] = expecteds[U2][S1] = false; expecteds[U2][S2] = false; expecteds[U2][S4] = true; expecteds[U2][S8] = true; expecteds[U2][S16] = true; expecteds[U2][S32] = true; expecteds[U2][S64] = true; - expecteds[U2][U1] = false; expecteds[U2][U2] = true; expecteds[U2][U4] = true; expecteds[U2][U8] = true; @@ -332,7 +314,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][C64] = true; expecteds[U2][BF16] = true; expecteds[U2][C128] = true; - expecteds[U2][F4E2M1FN] = true; + expecteds[U2][BF16] = true; + expecteds[U2][C128] = true; expecteds[U2][F8E5M2] = true; expecteds[U2][F8E4M3] = true; expecteds[U2][F8E4M3FN] = true; @@ -340,7 +323,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][F8E5M2FNUZ] = true; expecteds[U2][F8E4M3FNUZ] = true; expecteds[U2][F8E3M4] = true; - expecteds[U2][F8E8M0FNU] = false; expecteds[U4][PRED] = false; expecteds[U4][S1] = false; expecteds[U4][S2] = false; @@ -362,7 +344,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][C64] = true; expecteds[U4][BF16] = true; expecteds[U4][C128] = true; - expecteds[U4][F4E2M1FN] = false; + expecteds[U4][BF16] = true; + expecteds[U4][C128] = true; expecteds[U4][F8E5M2] = false; expecteds[U4][F8E4M3] = true; expecteds[U4][F8E4M3FN] = true; @@ -370,7 +353,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][F8E5M2FNUZ] = false; expecteds[U4][F8E4M3FNUZ] = true; expecteds[U4][F8E3M4] = true; - expecteds[U4][F8E8M0FNU] = false; expecteds[U8][PRED] = false; expecteds[U8][S1] = false; expecteds[U8][S2] = false; @@ -392,7 +374,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][C64] = true; expecteds[U8][BF16] = true; expecteds[U8][C128] = true; - expecteds[U8][F4E2M1FN] = false; + expecteds[U8][BF16] = true; + expecteds[U8][C128] = true; expecteds[U8][F8E5M2] = false; expecteds[U8][F8E4M3] = false; expecteds[U8][F8E4M3FN] = false; @@ -400,7 +383,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][F8E5M2FNUZ] = false; expecteds[U8][F8E4M3FNUZ] = false; expecteds[U8][F8E3M4] = false; - expecteds[U8][F8E8M0FNU] = false; expecteds[U16][PRED] = false; expecteds[U16][S1] = false; expecteds[U16][S2] = false; @@ -422,7 +404,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][C64] = true; expecteds[U16][BF16] = false; expecteds[U16][C128] = true; - expecteds[U16][F4E2M1FN] = false; expecteds[U16][F8E5M2] = false; expecteds[U16][F8E4M3] = false; expecteds[U16][F8E4M3FN] = false; @@ -430,7 +411,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][F8E5M2FNUZ] = false; expecteds[U16][F8E4M3FNUZ] = false; expecteds[U16][F8E3M4] = false; - expecteds[U16][F8E8M0FNU] = false; expecteds[U32][PRED] = false; expecteds[U32][S1] = false; expecteds[U32][S2] = false; @@ -452,7 +432,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][C64] = false; expecteds[U32][BF16] = false; expecteds[U32][C128] = true; - expecteds[U32][F4E2M1FN] = false; expecteds[U32][F8E5M2] = false; expecteds[U32][F8E4M3] = false; expecteds[U32][F8E4M3FN] = false; @@ -460,7 +439,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][F8E5M2FNUZ] = false; expecteds[U32][F8E4M3FNUZ] = false; expecteds[U32][F8E3M4] = false; - expecteds[U32][F8E8M0FNU] = false; expecteds[U64][PRED] = false; expecteds[U64][S1] = false; expecteds[U64][S2] = false; @@ -482,7 +460,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][C64] = false; expecteds[U64][BF16] = false; expecteds[U64][C128] = false; - expecteds[U64][F4E2M1FN] = false; expecteds[U64][F8E5M2] = false; expecteds[U64][F8E4M3] = false; expecteds[U64][F8E4M3FN] = false; @@ -490,7 +467,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][F8E5M2FNUZ] = false; expecteds[U64][F8E4M3FNUZ] = false; expecteds[U64][F8E3M4] = false; - expecteds[U64][F8E8M0FNU] = false; expecteds[F16][PRED] = false; expecteds[F16][S1] = false; expecteds[F16][S2] = false; @@ -512,7 +488,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][C64] = true; expecteds[F16][BF16] = false; expecteds[F16][C128] = true; - expecteds[F16][F4E2M1FN] = false; expecteds[F16][F8E5M2] = false; expecteds[F16][F8E4M3] = false; expecteds[F16][F8E4M3FN] = false; @@ -520,7 +495,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][F8E5M2FNUZ] = false; expecteds[F16][F8E4M3FNUZ] = false; expecteds[F16][F8E3M4] = false; - expecteds[F16][F8E8M0FNU] = false; expecteds[F32][PRED] = false; expecteds[F32][S1] = false; expecteds[F32][S2] = false; @@ -542,7 +516,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][C64] = true; expecteds[F32][BF16] = false; expecteds[F32][C128] = true; - expecteds[F32][F4E2M1FN] = false; expecteds[F32][F8E5M2] = false; expecteds[F32][F8E4M3] = false; expecteds[F32][F8E4M3FN] = false; @@ -550,7 +523,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][F8E5M2FNUZ] = false; expecteds[F32][F8E4M3FNUZ] = false; expecteds[F32][F8E3M4] = false; - expecteds[F32][F8E8M0FNU] = false; expecteds[F64][PRED] = false; expecteds[F64][S1] = false; expecteds[F64][S2] = false; @@ -572,7 +544,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][C64] = false; expecteds[F64][BF16] = false; expecteds[F64][C128] = true; - expecteds[F64][F4E2M1FN] = false; expecteds[F64][F8E5M2] = false; expecteds[F64][F8E4M3] = false; expecteds[F64][F8E4M3FN] = false; @@ -580,7 +551,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][F8E5M2FNUZ] = false; expecteds[F64][F8E4M3FNUZ] = false; expecteds[F64][F8E3M4] = false; - expecteds[F64][F8E8M0FNU] = false; expecteds[C64][PRED] = false; expecteds[C64][S1] = false; expecteds[C64][S2] = false; @@ -602,7 +572,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][C64] = true; expecteds[C64][BF16] = false; expecteds[C64][C128] = true; - expecteds[C64][F4E2M1FN] = false; expecteds[C64][F8E5M2] = false; expecteds[C64][F8E4M3] = false; expecteds[C64][F8E4M3FN] = false; @@ -610,7 +579,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][F8E5M2FNUZ] = false; expecteds[C64][F8E4M3FNUZ] = false; expecteds[C64][F8E3M4] = false; - expecteds[C64][F8E8M0FNU] = false; expecteds[BF16][PRED] = false; expecteds[BF16][S1] = false; expecteds[BF16][S2] = false; @@ -632,7 +600,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][C64] = true; expecteds[BF16][BF16] = true; expecteds[BF16][C128] = true; - expecteds[BF16][F4E2M1FN] = false; expecteds[BF16][F8E5M2] = false; expecteds[BF16][F8E4M3] = false; expecteds[BF16][F8E4M3FN] = false; @@ -640,7 +607,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][F8E5M2FNUZ] = false; expecteds[BF16][F8E4M3FNUZ] = false; expecteds[BF16][F8E3M4] = false; - expecteds[BF16][F8E8M0FNU] = false; expecteds[C128][PRED] = false; expecteds[C128][S1] = false; expecteds[C128][S2] = false; @@ -662,7 +628,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][C64] = false; expecteds[C128][BF16] = false; expecteds[C128][C128] = true; - expecteds[C128][F4E2M1FN] = false; expecteds[C128][F8E5M2] = false; expecteds[C128][F8E4M3] = false; expecteds[C128][F8E4M3FN] = false; @@ -670,37 +635,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][F8E5M2FNUZ] = false; expecteds[C128][F8E4M3FNUZ] = false; expecteds[C128][F8E3M4] = false; - expecteds[C128][F8E8M0FNU] = false; - expecteds[F4E2M1FN][PRED] = false; - expecteds[F4E2M1FN][S1] = false; - expecteds[F4E2M1FN][S2] = false; - expecteds[F4E2M1FN][S4] = false; - expecteds[F4E2M1FN][S8] = false; - expecteds[F4E2M1FN][S16] = false; - expecteds[F4E2M1FN][S32] = false; - expecteds[F4E2M1FN][S64] = false; - expecteds[F4E2M1FN][U1] = false; - expecteds[F4E2M1FN][U2] = false; - expecteds[F4E2M1FN][U4] = false; - expecteds[F4E2M1FN][U8] = false; - expecteds[F4E2M1FN][U16] = false; - expecteds[F4E2M1FN][U32] = false; - expecteds[F4E2M1FN][U64] = false; - expecteds[F4E2M1FN][F16] = true; - expecteds[F4E2M1FN][F32] = true; - expecteds[F4E2M1FN][F64] = true; - expecteds[F4E2M1FN][C64] = true; - expecteds[F4E2M1FN][BF16] = true; - expecteds[F4E2M1FN][C128] = true; - expecteds[F4E2M1FN][F4E2M1FN] = true; - expecteds[F4E2M1FN][F8E5M2] = true; - expecteds[F4E2M1FN][F8E4M3] = true; - expecteds[F4E2M1FN][F8E4M3FN] = true; - expecteds[F4E2M1FN][F8E4M3B11FNUZ] = false; - expecteds[F4E2M1FN][F8E4M3FNUZ] = false; - expecteds[F4E2M1FN][F8E5M2FNUZ] = false; - expecteds[F4E2M1FN][F8E3M4] = true; - expecteds[F4E2M1FN][F8E8M0FNU] = false; expecteds[F8E5M2][PRED] = false; expecteds[F8E5M2][S1] = false; expecteds[F8E5M2][S2] = false; @@ -722,7 +656,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][C64] = true; expecteds[F8E5M2][BF16] = true; expecteds[F8E5M2][C128] = true; - expecteds[F8E5M2][F4E2M1FN] = false; expecteds[F8E5M2][F8E5M2] = true; expecteds[F8E5M2][F8E4M3] = false; expecteds[F8E5M2][F8E4M3FN] = false; @@ -730,7 +663,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][F8E5M2FNUZ] = false; expecteds[F8E5M2][F8E4M3FNUZ] = false; expecteds[F8E5M2][F8E3M4] = false; - expecteds[F8E5M2][F8E8M0FNU] = false; expecteds[F8E4M3][PRED] = false; expecteds[F8E4M3][S1] = false; expecteds[F8E4M3][S2] = false; @@ -752,7 +684,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3][C64] = true; expecteds[F8E4M3][BF16] = true; expecteds[F8E4M3][C128] = true; - expecteds[F8E4M3][F4E2M1FN] = false; expecteds[F8E4M3][F8E5M2] = false; expecteds[F8E4M3][F8E5M2FNUZ] = false; expecteds[F8E4M3][F8E4M3] = true; @@ -760,7 +691,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3][F8E4M3FNUZ] = false; expecteds[F8E4M3][F8E4M3B11FNUZ] = false; expecteds[F8E4M3][F8E3M4] = false; - expecteds[F8E4M3][F8E8M0FNU] = false; expecteds[F8E4M3FN][PRED] = false; expecteds[F8E4M3FN][S1] = false; expecteds[F8E4M3FN][S2] = false; @@ -782,7 +712,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][C64] = true; expecteds[F8E4M3FN][BF16] = true; expecteds[F8E4M3FN][C128] = true; - expecteds[F8E4M3FN][F4E2M1FN] = false; expecteds[F8E4M3FN][F8E5M2] = false; expecteds[F8E4M3FN][F8E5M2FNUZ] = false; expecteds[F8E4M3FN][F8E4M3] = false; @@ -790,7 +719,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][F8E4M3FNUZ] = false; expecteds[F8E4M3FN][F8E4M3B11FNUZ] = false; expecteds[F8E4M3FN][F8E3M4] = false; - expecteds[F8E4M3FN][F8E8M0FNU] = false; expecteds[F8E4M3B11FNUZ][PRED] = false; expecteds[F8E4M3B11FNUZ][S1] = false; expecteds[F8E4M3B11FNUZ][S2] = false; @@ -812,7 +740,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][C64] = true; expecteds[F8E4M3B11FNUZ][BF16] = true; expecteds[F8E4M3B11FNUZ][C128] = true; - expecteds[F8E4M3B11FNUZ][F4E2M1FN] = false; expecteds[F8E4M3B11FNUZ][F8E5M2] = false; expecteds[F8E4M3B11FNUZ][F8E4M3] = false; expecteds[F8E4M3B11FNUZ][F8E4M3FN] = false; @@ -820,7 +747,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][F8E4M3FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E3M4] = false; - expecteds[F8E4M3B11FNUZ][F8E8M0FNU] = false; expecteds[F8E5M2FNUZ][PRED] = false; expecteds[F8E5M2FNUZ][S1] = false; expecteds[F8E5M2FNUZ][S2] = false; @@ -842,7 +768,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][C64] = true; expecteds[F8E5M2FNUZ][BF16] = true; expecteds[F8E5M2FNUZ][C128] = true; - expecteds[F8E5M2FNUZ][F4E2M1FN] = false; expecteds[F8E5M2FNUZ][F8E5M2] = false; expecteds[F8E5M2FNUZ][F8E4M3] = false; expecteds[F8E5M2FNUZ][F8E4M3FN] = false; @@ -850,7 +775,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][F8E5M2FNUZ] = true; expecteds[F8E5M2FNUZ][F8E4M3FNUZ] = false; expecteds[F8E5M2FNUZ][F8E3M4] = false; - expecteds[F8E5M2FNUZ][F8E8M0FNU] = false; expecteds[F8E4M3FNUZ][PRED] = false; expecteds[F8E4M3FNUZ][S1] = false; expecteds[F8E4M3FNUZ][S2] = false; @@ -872,7 +796,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][C64] = true; expecteds[F8E4M3FNUZ][BF16] = true; expecteds[F8E4M3FNUZ][C128] = true; - expecteds[F8E4M3FNUZ][F4E2M1FN] = false; expecteds[F8E4M3FNUZ][F8E5M2] = false; expecteds[F8E4M3FNUZ][F8E4M3] = false; expecteds[F8E4M3FNUZ][F8E4M3FN] = false; @@ -880,7 +803,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3FNUZ][F8E4M3FNUZ] = true; expecteds[F8E4M3FNUZ][F8E3M4] = false; - expecteds[F8E4M3FNUZ][F8E8M0FNU] = false; expecteds[F8E3M4][PRED] = false; expecteds[F8E3M4][S1] = false; expecteds[F8E3M4][S2] = false; @@ -902,7 +824,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E3M4][C64] = true; expecteds[F8E3M4][BF16] = true; expecteds[F8E3M4][C128] = true; - expecteds[F8E3M4][F4E2M1FN] = false; expecteds[F8E3M4][F8E5M2] = false; expecteds[F8E3M4][F8E5M2FNUZ] = false; expecteds[F8E3M4][F8E4M3] = false; @@ -910,37 +831,6 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E3M4][F8E4M3FNUZ] = false; expecteds[F8E3M4][F8E4M3B11FNUZ] = false; expecteds[F8E3M4][F8E3M4] = true; - expecteds[F8E3M4][F8E8M0FNU] = false; - expecteds[F8E8M0FNU][PRED] = false; - expecteds[F8E8M0FNU][S1] = false; - expecteds[F8E8M0FNU][S2] = false; - expecteds[F8E8M0FNU][S4] = false; - expecteds[F8E8M0FNU][S8] = false; - expecteds[F8E8M0FNU][S16] = false; - expecteds[F8E8M0FNU][S32] = false; - expecteds[F8E8M0FNU][S64] = false; - expecteds[F8E8M0FNU][U1] = false; - expecteds[F8E8M0FNU][U2] = false; - expecteds[F8E8M0FNU][U4] = false; - expecteds[F8E8M0FNU][U8] = false; - expecteds[F8E8M0FNU][U16] = false; - expecteds[F8E8M0FNU][U32] = false; - expecteds[F8E8M0FNU][U64] = false; - expecteds[F8E8M0FNU][F16] = false; - expecteds[F8E8M0FNU][F32] = true; - expecteds[F8E8M0FNU][F64] = true; - expecteds[F8E8M0FNU][C64] = true; - expecteds[F8E8M0FNU][BF16] = true; - expecteds[F8E8M0FNU][C128] = true; - expecteds[F8E8M0FNU][F4E2M1FN] = false; - expecteds[F8E8M0FNU][F8E5M2] = false; - expecteds[F8E8M0FNU][F8E4M3] = false; - expecteds[F8E8M0FNU][F8E4M3FN] = false; - expecteds[F8E8M0FNU][F8E4M3B11FNUZ] = false; - expecteds[F8E8M0FNU][F8E4M3FNUZ] = false; - expecteds[F8E8M0FNU][F8E5M2FNUZ] = false; - expecteds[F8E8M0FNU][F8E3M4] = false; - expecteds[F8E8M0FNU][F8E8M0FNU] = true; for (int from_type_int = PrimitiveType_MIN; from_type_int < PrimitiveType_ARRAYSIZE; ++from_type_int) { @@ -961,7 +851,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { << primitive_util::LowercasePrimitiveTypeName(to_type); } } -} // NOLINT(readability/fn_size) +} } // namespace } // namespace xla diff --git a/xla/python/ifrt/dtype.cc b/xla/python/ifrt/dtype.cc index e1110543cb11a..a79240f51a7e2 100644 --- a/xla/python/ifrt/dtype.cc +++ b/xla/python/ifrt/dtype.cc @@ -32,7 +32,6 @@ std::optional DType::byte_size() const { case kU2: case kS4: case kU4: - case kF4E2M1FN: // Smaller than a byte. return std::nullopt; case kPred: @@ -40,7 +39,6 @@ std::optional DType::byte_size() const { case kU8: case kF8E3M4: case kF8E4M3: - case kF8E8M0FNU: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -79,14 +77,12 @@ std::optional DType::bit_size() const { return 2; case kS4: case kU4: - case kF4E2M1FN: return 4; case kPred: case kS8: case kU8: case kF8E3M4: case kF8E4M3: - case kF8E8M0FNU: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -145,11 +141,9 @@ absl::StatusOr DType::FromProto(const DTypeProto& dtype_proto) { CASE(BF16); CASE(C64); CASE(C128); - CASE(F4E2M1FN); // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. // CASE(F8E3M4); // CASE(F8E4M3); - CASE(F8E8M0FNU); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); @@ -195,11 +189,9 @@ DTypeProto DType::ToProto() const { CASE(BF16); CASE(C64); CASE(C128); - CASE(F4E2M1FN); // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. // CASE(F8E3M4); // CASE(F8E4M3); - CASE(F8E8M0FNU); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); diff --git a/xla/python/ifrt/dtype.h b/xla/python/ifrt/dtype.h index 864cdd1c063ae..d23efc55a1aa1 100644 --- a/xla/python/ifrt/dtype.h +++ b/xla/python/ifrt/dtype.h @@ -88,12 +88,8 @@ class DType { kF8E4M3FNUZ = 25, kF8E5M2 = 19, kF8E5M2FNUZ = 24, - kF8E8M0FNU = 33, - // MX floating point types. - kF4E2M1FN = 32, - - // Next = 34 + // Next = 30 // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind diff --git a/xla/python/ifrt/dtype.proto b/xla/python/ifrt/dtype.proto index 2cf453f26c291..3a2b0df7976d6 100644 --- a/xla/python/ifrt/dtype.proto +++ b/xla/python/ifrt/dtype.proto @@ -70,18 +70,12 @@ message DTypeProto { KIND_F8E4M3FNUZ = 25; KIND_F8E5M2 = 19; KIND_F8E5M2FNUZ = 24; - KIND_F8E8M0FNU = 31; - - // MX floating point types. - KIND_F4E2M1FN = 30; // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind // needs to match xla.PrimitiveType enum, so choose a large enum to avoid // collision. KIND_STRING = 99; - - // Next: 32 } // LINT.ThenChange() Kind kind = 1; diff --git a/xla/python/ifrt/dtype_test.cc b/xla/python/ifrt/dtype_test.cc index 9d3d3105f54e5..57fec6702d277 100644 --- a/xla/python/ifrt/dtype_test.cc +++ b/xla/python/ifrt/dtype_test.cc @@ -42,21 +42,34 @@ TEST(DTypeTest, FromToFromProto) { TEST(DTypeTest, ByteSize) { for (const auto& [kind, byte_size] : std::vector>({ - {DType::kS2, -1}, {DType::kU2, -1}, - {DType::kS4, -1}, {DType::kU4, -1}, - {DType::kPred, 1}, {DType::kS8, 1}, - {DType::kU8, 1}, {DType::kF4E2M1FN, -1}, - {DType::kF8E3M4, 1}, {DType::kF8E4M3, 1}, - {DType::kF8E4M3FN, 1}, {DType::kF8E4M3B11FNUZ, 1}, - {DType::kF8E4M3FNUZ, 1}, {DType::kF8E5M2, 1}, - {DType::kF8E5M2FNUZ, 1}, {DType::kF8E8M0FNU, 1}, - {DType::kS16, 2}, {DType::kU16, 2}, - {DType::kF16, 2}, {DType::kBF16, 2}, - {DType::kS32, 4}, {DType::kU32, 4}, - {DType::kF32, 4}, {DType::kS64, 8}, - {DType::kU64, 8}, {DType::kF64, 8}, - {DType::kC64, 8}, {DType::kC128, 16}, - {DType::kToken, -1}, {DType::kInvalid, -1}, + {DType::kS2, -1}, + {DType::kU2, -1}, + {DType::kS4, -1}, + {DType::kU4, -1}, + {DType::kPred, 1}, + {DType::kS8, 1}, + {DType::kU8, 1}, + {DType::kF8E3M4, 1}, + {DType::kF8E4M3, 1}, + {DType::kF8E4M3FN, 1}, + {DType::kF8E4M3B11FNUZ, 1}, + {DType::kF8E4M3FNUZ, 1}, + {DType::kF8E5M2, 1}, + {DType::kF8E5M2FNUZ, 1}, + {DType::kS16, 2}, + {DType::kU16, 2}, + {DType::kF16, 2}, + {DType::kBF16, 2}, + {DType::kS32, 4}, + {DType::kU32, 4}, + {DType::kF32, 4}, + {DType::kS64, 8}, + {DType::kU64, 8}, + {DType::kF64, 8}, + {DType::kC64, 8}, + {DType::kC128, 16}, + {DType::kToken, -1}, + {DType::kInvalid, -1}, {DType::kString, -1}, })) { EXPECT_EQ(DType(kind).byte_size(), @@ -67,21 +80,34 @@ TEST(DTypeTest, ByteSize) { TEST(DTypeTest, BitSize) { for (const auto& [kind, bit_size] : std::vector>({ - {DType::kS2, 2}, {DType::kU2, 2}, - {DType::kS4, 4}, {DType::kU4, 4}, - {DType::kPred, 8}, {DType::kS8, 8}, - {DType::kU8, 8}, {DType::kF4E2M1FN, 4}, - {DType::kF8E3M4, 8}, {DType::kF8E4M3, 8}, - {DType::kF8E4M3FN, 8}, {DType::kF8E4M3B11FNUZ, 8}, - {DType::kF8E4M3FNUZ, 8}, {DType::kF8E5M2, 8}, - {DType::kF8E5M2FNUZ, 8}, {DType::kF8E8M0FNU, 8}, - {DType::kS16, 16}, {DType::kU16, 16}, - {DType::kF16, 16}, {DType::kBF16, 16}, - {DType::kS32, 32}, {DType::kU32, 32}, - {DType::kF32, 32}, {DType::kS64, 64}, - {DType::kU64, 64}, {DType::kF64, 64}, - {DType::kC64, 64}, {DType::kC128, 128}, - {DType::kToken, -1}, {DType::kInvalid, -1}, + {DType::kS2, 2}, + {DType::kU2, 2}, + {DType::kS4, 4}, + {DType::kU4, 4}, + {DType::kPred, 8}, + {DType::kS8, 8}, + {DType::kU8, 8}, + {DType::kF8E3M4, 8}, + {DType::kF8E4M3, 8}, + {DType::kF8E4M3FN, 8}, + {DType::kF8E4M3B11FNUZ, 8}, + {DType::kF8E4M3FNUZ, 8}, + {DType::kF8E5M2, 8}, + {DType::kF8E5M2FNUZ, 8}, + {DType::kS16, 16}, + {DType::kU16, 16}, + {DType::kF16, 16}, + {DType::kBF16, 16}, + {DType::kS32, 32}, + {DType::kU32, 32}, + {DType::kF32, 32}, + {DType::kS64, 64}, + {DType::kU64, 64}, + {DType::kF64, 64}, + {DType::kC64, 64}, + {DType::kC128, 128}, + {DType::kToken, -1}, + {DType::kInvalid, -1}, {DType::kString, -1}, })) { EXPECT_EQ(DType(kind).bit_size(), diff --git a/xla/python/pjrt_ifrt/pjrt_dtype.cc b/xla/python/pjrt_ifrt/pjrt_dtype.cc index 2af3281a588cc..9c581ec6227ca 100644 --- a/xla/python/pjrt_ifrt/pjrt_dtype.cc +++ b/xla/python/pjrt_ifrt/pjrt_dtype.cc @@ -44,7 +44,6 @@ absl::StatusOr ToPrimitiveType(DType dtype) { CASE(DType::kU16, xla::PrimitiveType::U16); CASE(DType::kU32, xla::PrimitiveType::U32); CASE(DType::kU64, xla::PrimitiveType::U64); - CASE(DType::kF4E2M1FN, xla::PrimitiveType::F4E2M1FN); CASE(DType::kF8E3M4, xla::PrimitiveType::F8E3M4); CASE(DType::kF8E4M3, xla::PrimitiveType::F8E4M3); CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN); @@ -52,7 +51,6 @@ absl::StatusOr ToPrimitiveType(DType dtype) { CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ); CASE(DType::kF8E5M2, xla::PrimitiveType::F8E5M2); CASE(DType::kF8E5M2FNUZ, xla::PrimitiveType::F8E5M2FNUZ); - CASE(DType::kF8E8M0FNU, xla::PrimitiveType::F8E8M0FNU); CASE(DType::kF16, xla::PrimitiveType::F16); CASE(DType::kF32, xla::PrimitiveType::F32); CASE(DType::kBF16, xla::PrimitiveType::BF16); @@ -85,7 +83,6 @@ absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::U16: case xla::PrimitiveType::U32: case xla::PrimitiveType::U64: - case xla::PrimitiveType::F4E2M1FN: case xla::PrimitiveType::F8E3M4: case xla::PrimitiveType::F8E4M3: case xla::PrimitiveType::F8E4M3FN: @@ -93,7 +90,6 @@ absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::F8E4M3FNUZ: case xla::PrimitiveType::F8E5M2: case xla::PrimitiveType::F8E5M2FNUZ: - case xla::PrimitiveType::F8E8M0FNU: case xla::PrimitiveType::F16: case xla::PrimitiveType::F32: case xla::PrimitiveType::BF16: diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index 45baa4abf7935..631b0bcb9b956 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -184,9 +184,6 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = BF16; - } else if (std::is_same()) { - PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); - type = F4E2M1FN; } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E3M4; @@ -208,9 +205,6 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E5M2FNUZ; - } else if (std::is_same()) { - PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); - type = F8E8M0FNU; } else if (std::is_same() || !options.squash_64bit_types) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<0>()); type = primitive_util::NativeToPrimitiveType(); @@ -404,10 +398,6 @@ absl::StatusOr DevicePut(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; - if (dtypes.np_float4_e2m1fn.has_value()) { - (*p)[dtypes.np_float4_e2m1fn->ptr()] = - HandleNumpyScalar; - } if (dtypes.np_float8_e3m4.has_value()) { (*p)[dtypes.np_float8_e3m4->ptr()] = HandleNumpyScalar; @@ -425,10 +415,6 @@ absl::StatusOr DevicePut(nb::handle arg, HandleNumpyScalar; (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = HandleNumpyScalar; - if (dtypes.np_float8_e8m0fnu.has_value()) { - (*p)[dtypes.np_float8_e8m0fnu->ptr()] = - HandleNumpyScalar; - } (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; @@ -609,10 +595,8 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - // (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler; // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; - // (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; diff --git a/xla/python/types.cc b/xla/python/types.cc index 473c082e1425c..50366be350bc0 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -58,7 +58,6 @@ namespace { struct CustomDtypes { nb_dtype bfloat16; - std::optional float4_e2m1fn; std::optional float8_e3m4; std::optional float8_e4m3; nb_dtype float8_e4m3fn; @@ -66,7 +65,6 @@ struct CustomDtypes { nb_dtype float8_e4m3fnuz; nb_dtype float8_e5m2; nb_dtype float8_e5m2fnuz; - std::optional float8_e8m0fnu; std::optional int2; nb_dtype int4; std::optional uint2; @@ -78,10 +76,6 @@ const CustomDtypes& GetCustomDtypes() { nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes"); auto* dtypes = new CustomDtypes; dtypes->bfloat16 = nb_dtype::from_args(ml_dtypes.attr("bfloat16")); - if (nb::hasattr(ml_dtypes, "float4_e2m1fn")) { - dtypes->float4_e2m1fn = - nb_dtype::from_args(ml_dtypes.attr("float4_e2m1fn")); - } if (nb::hasattr(ml_dtypes, "float8_e3m4")) { dtypes->float8_e3m4 = nb_dtype::from_args(ml_dtypes.attr("float8_e3m4")); } @@ -97,10 +91,6 @@ const CustomDtypes& GetCustomDtypes() { nb_dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")); dtypes->float8_e5m2fnuz = nb_dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")); - if (nb::hasattr(ml_dtypes, "float8_e8m0fnu")) { - dtypes->float8_e8m0fnu = - nb_dtype::from_args(ml_dtypes.attr("float8_e8m0fnu")); - } dtypes->int4 = nb_dtype::from_args(ml_dtypes.attr("int4")); dtypes->uint4 = nb_dtype::from_args(ml_dtypes.attr("uint4")); if (nb::hasattr(ml_dtypes, "int2")) { @@ -157,9 +147,6 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { auto* map = new absl::flat_hash_map(); map->emplace(custom_dtypes.bfloat16, BF16); - if (custom_dtypes.float4_e2m1fn.has_value()) { - map->emplace(*custom_dtypes.float4_e2m1fn, F4E2M1FN); - } if (custom_dtypes.float8_e3m4.has_value()) { map->emplace(*custom_dtypes.float8_e3m4, F8E3M4); } @@ -171,9 +158,6 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { map->emplace(custom_dtypes.float8_e4m3fnuz, F8E4M3FNUZ); map->emplace(custom_dtypes.float8_e5m2, F8E5M2); map->emplace(custom_dtypes.float8_e5m2fnuz, F8E5M2FNUZ); - if (custom_dtypes.float8_e8m0fnu.has_value()) { - map->emplace(*custom_dtypes.float8_e8m0fnu, F8E8M0FNU); - } if (custom_dtypes.int2.has_value()) { map->emplace(*custom_dtypes.int2, S2); } @@ -233,11 +217,6 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return to_nb_dtype(NPY_UINT32); case U64: return to_nb_dtype(NPY_UINT64); - case F4E2M1FN: - if (custom_dtypes.float4_e2m1fn.has_value()) { - return *custom_dtypes.float4_e2m1fn; - } - break; case F8E3M4: if (custom_dtypes.float8_e3m4.has_value()) { return *custom_dtypes.float8_e3m4; @@ -258,11 +237,6 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return custom_dtypes.float8_e5m2; case F8E5M2FNUZ: return custom_dtypes.float8_e5m2fnuz; - case F8E8M0FNU: - if (custom_dtypes.float8_e8m0fnu.has_value()) { - return *custom_dtypes.float8_e8m0fnu; - } - break; case BF16: return custom_dtypes.bfloat16; case F16: @@ -333,11 +307,6 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return to_nb_dtype(NPY_COMPLEX64); case ifrt::DType::kC128: return to_nb_dtype(NPY_COMPLEX128); - case ifrt::DType::kF4E2M1FN: - if (custom_dtypes.float4_e2m1fn.has_value()) { - return *custom_dtypes.float4_e2m1fn; - } - break; case ifrt::DType::kF8E3M4: if (custom_dtypes.float8_e3m4.has_value()) { return *custom_dtypes.float8_e3m4; @@ -358,11 +327,6 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return custom_dtypes.float8_e5m2; case ifrt::DType::kF8E5M2FNUZ: return custom_dtypes.float8_e5m2fnuz; - case ifrt::DType::kF8E8M0FNU: - if (custom_dtypes.float8_e8m0fnu.has_value()) { - return *custom_dtypes.float8_e8m0fnu; - } - break; case ifrt::DType::kString: // PEP 3118 code for "pointer to Python Object". We use Python objects // instead of 'U' (Unicode string) or 'V' (raw data) because the latter @@ -416,9 +380,6 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_uint32 = nb::object(numpy.attr("uint32")); dtypes->np_uint64 = nb::object(numpy.attr("uint64")); dtypes->np_bfloat16 = nb::object(ml_dtypes.attr("bfloat16")); - if (nb::hasattr(ml_dtypes, "float4_e2m1fn")) { - dtypes->np_float4_e2m1fn = nb::object(ml_dtypes.attr("float4_e2m1fn")); - } if (nb::hasattr(ml_dtypes, "float8_e3m4")) { dtypes->np_float8_e3m4 = nb::object(ml_dtypes.attr("float8_e3m4")); } @@ -431,9 +392,6 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_float8_e5m2 = nb::object(ml_dtypes.attr("float8_e5m2")); dtypes->np_float8_e4m3fnuz = nb::object(ml_dtypes.attr("float8_e4m3fnuz")); dtypes->np_float8_e5m2fnuz = nb::object(ml_dtypes.attr("float8_e5m2fnuz")); - if (nb::hasattr(ml_dtypes, "float8_e8m0fnu")) { - dtypes->np_float8_e8m0fnu = nb::object(ml_dtypes.attr("float8_e8m0fnu")); - } dtypes->np_float16 = nb::object(numpy.attr("float16")); dtypes->np_float32 = nb::object(numpy.attr("float32")); dtypes->np_float64 = nb::object(numpy.attr("float64")); diff --git a/xla/python/types.h b/xla/python/types.h index babdf5a9bd416..aacfea1a17997 100644 --- a/xla/python/types.h +++ b/xla/python/types.h @@ -81,7 +81,6 @@ struct NumpyScalarTypes { nanobind::object np_uint64; nanobind::object np_bfloat16; // Remove std::optional once the minimum ml_dtypes in JAX is >= 0.5.0. - std::optional np_float4_e2m1fn; std::optional np_float8_e3m4; std::optional np_float8_e4m3; nanobind::object np_float8_e4m3fn; @@ -89,7 +88,6 @@ struct NumpyScalarTypes { nanobind::object np_float8_e4m3fnuz; nanobind::object np_float8_e5m2; nanobind::object np_float8_e5m2fnuz; - std::optional np_float8_e8m0fnu; nanobind::object np_float16; nanobind::object np_float32; nanobind::object np_float64; diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 6a3d259b3589c..219d6704b4f79 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -204,11 +204,9 @@ NB_MODULE(xla_extension, m) { .value("U32", U32) .value("U64", U64) .value("F16", F16) - .value("F4E2M1FN", F4E2M1FN) // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. // .value("F8E3M4", F8E3M4) // .value("F8E4M3", F8E4M3) - .value("F8E8M0FNU", F8E8M0FNU) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) .value("F8E4M3FNUZ", F8E4M3FNUZ) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index c58346f7f3ca9..040c781cd087d 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -280,12 +280,8 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): bfloat16 = ml_dtypes.bfloat16 # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. -# Also, it would be better to conditionally import these based on whether they -# are in the current version of ml_dtypes. -# float4_e2m1fn = ml_dtypes.float4_e2m1fn # float8_e3m4 = ml_dtypes.float8_e3m4 # float8_e4m3 = ml_dtypes.float8_e4m3 -# float8_e8m0fnu = ml_dtypes.float8_e8m0fnu float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz @@ -305,10 +301,8 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType.U32: np.dtype('uint32'), PrimitiveType.U64: np.dtype('uint64'), # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - # PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), - # PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), PrimitiveType.F8E5M2: np.dtype(float8_e5m2), diff --git a/xla/python/xla_client.pyi b/xla/python/xla_client.pyi index c1bb4dbc3a6fc..cac63a98c1b2d 100644 --- a/xla/python/xla_client.pyi +++ b/xla/python/xla_client.pyi @@ -62,10 +62,8 @@ mlir_api_version: int bfloat16: type[numpy.generic] # TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. -# float4_e2m1fn: type[numpy.generic] # float8_e3m4: type[numpy.generic] # float8_e4m3: type[numpy.generic] -# float8_e8m0fnu: type[numpy.generic] float8_e4m3fn: type[numpy.generic] float8_e4m3b11fnuz: type[numpy.generic] float8_e4m3fnuz: type[numpy.generic] diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index 37718e3fa8790..35b4a1ee77964 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -55,10 +55,8 @@ bfloat16 = xla_client.bfloat16 # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. -# float4_e2m1fn = xla_client.float4_e2m1fn # float8_e3m4 = xla_client.float8_e3m4 # float8_e4m3 = xla_client.float8_e4m3 -# float8_e8m0fnu = xla_client.float8_e8m0fnu float8_e4m3fn = xla_client.float8_e4m3fn float8_e4m3fnuz = xla_client.float8_e4m3fnuz float8_e4m3b11fnuz = xla_client.float8_e4m3b11fnuz @@ -191,7 +189,7 @@ def TestFactory(xla_backend, fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] standard_dtypes += fp8_dtypes # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - # standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3, float8_e8m0fnu] + # standard_dtypes += [float8_e3m4, float8_e4m3] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index ee7df05462f7b..2e3862285898f 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -74,7 +74,6 @@ class PrimitiveType(enum.IntEnum): U16: PrimitiveType U32: PrimitiveType U64: PrimitiveType - F4E2M1FN: PrimitiveType F8E3M4: PrimitiveType F8E4M3: PrimitiveType F8E4M3FN: PrimitiveType @@ -82,7 +81,6 @@ class PrimitiveType(enum.IntEnum): F8E4M3FNUZ: PrimitiveType F8E5M2: PrimitiveType F8E5M2FNUZ: PrimitiveType - F8E8M0FNU: PrimitiveType BF16: PrimitiveType F16: PrimitiveType F32: PrimitiveType diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 564cb0a5cf8a0..0faa9f4826398 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -601,10 +601,6 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(&s4_support); FloatSupport u4_support(U4, U8); pipeline.AddPass(&u4_support); - FloatSupport f4e2m1fn_support(F4E2M1FN, F16); - pipeline.AddPass(&f4e2m1fn_support); - FloatSupport f8e8m0fnu_support(F8E8M0FNU, F32); - pipeline.AddPass(&f8e8m0fnu_support); // After canonicalization, there may be more batch dots that can be // simplified. pipeline.AddPass(); diff --git a/xla/service/cpu/onednn_memory_util.h b/xla/service/cpu/onednn_memory_util.h index 90c4f6c82e408..18841d2712dcb 100644 --- a/xla/service/cpu/onednn_memory_util.h +++ b/xla/service/cpu/onednn_memory_util.h @@ -73,7 +73,7 @@ inline dnnl::memory::data_type ToOneDnnDataType(PrimitiveType ptype) { // TODO(intel-tf): properly handle not supported types: // S16, S64, U16, U32, U64, C64, C128, F8E5M2, F8E4M3FN, S4, U4, - // F8E4M3B11FNUZ, F8E4M3, F8E3M4, F4E2M1FN, F8E8M0FNU + // F8E4M3B11FNUZ, F8E4M3, F8E3M4 default: return dt::undef; } diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index 83756d35eb4e3..083f07b8bc8fc 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -864,223 +864,6 @@ llvm::Value* EmitF8e4m3b11fnuzToF16(llvm::Value* f8_value, return f16_value; } -absl::StatusOr EmitF16ToF4e2m1fn(llvm::Value* f16_value, - llvm::IRBuilderBase* b) { - auto i8_const = [&](int val) { - return llvm::ConstantInt::get(b->getInt8Ty(), val); - }; - auto i16_const = [&](int val) { - return llvm::ConstantInt::get(b->getInt16Ty(), val); - }; - constexpr int mantissa_diff = 9; // 10 for F16, 1 for F4 - constexpr int bias_diff = 14; // 15 for F16, 1 for F4 - - // Cast the input value to an integer for bitwise manipulation. - // Get the absolute value of the input (discard the sign). - // f16_bits = bitcast(f16_value, int) - // f16_abs_bits = f16_bits & 0x7FFF - llvm::Value* f16_bits = b->CreateBitCast(f16_value, b->getInt16Ty()); - llvm::Value* f16_abs_bits = b->CreateAnd(f16_bits, i16_const(0x7FFF)); - - // If the input absolute value is >= 7.0 or an infinity, the result saturates - // to max value (6.0). If (0.75 <= input < 1), the result is rounded to 1.0. - // If (0 <= input <= 0.25), the result is rounded to 0.0. - // If the input is NaN, the result is undefined (implemented as minus zero). - // The rest of the cases are handled by the "happy path". - // is_overflow = f16_abs_bits >= 0x1.Cp2 - // is_one = f16_abs_bits >= 0x1.8p-1 (used only if exponent underflows) - // is_zero = f16_abs_bits <= 0x1p-2 (used only if exponent underflows) - // is_nan = f16_abs_bits > 0x7C00 (F16 NaN threshold) - llvm::Value* is_overflow = - b->CreateICmpUGE(f16_abs_bits, i16_const(0x4700)); // 7.0 - llvm::Value* is_one = - b->CreateICmpUGE(f16_abs_bits, i16_const(0x3A00)); // 0.75 - llvm::Value* is_zero = - b->CreateICmpULE(f16_abs_bits, i16_const(0x3400)); // 0.25 - llvm::Value* is_nan = - b->CreateICmpUGT(f16_abs_bits, i16_const(0x7C00)); // inf - - // Truncate the mantissa to 1 bit and the exponent to 3 bits (not 2 bits, as - // the type doesn't have Inf/NaN and can represent unbiased exponent 2). - // This case, as well as the denormal, is handled below. - TF_ASSIGN_OR_RETURN( - llvm::Value * reduced_precision, - EmitReducePrecisionIR( - /*src_ty=*/F16, f16_value, - /*dest_exponent_bits=*/primitive_util::ExponentWidth(F4E2M1FN) + 1, - /*dest_mantissa_bits=*/primitive_util::SignificandWidth(F4E2M1FN) - 1, - /*quiet_nans=*/false, b)); - - // Cast the reduced precision value to an integer for bitwise manipulation. - // Discard the least significant (9) mantissa bits leaving 1 bit. - // Truncate to - // as_int16 = bitcast(reduced_precision, int) - // as_int8 = as_int16 >> (f16_mantissa - f4_mantissa) - llvm::Value* as_int16 = b->CreateBitCast(reduced_precision, b->getInt16Ty()); - llvm::Value* as_int8 = - b->CreateTrunc(b->CreateLShr(as_int16, mantissa_diff), b->getInt8Ty()); - - // Get the sign (0 or 1). - // f4_sign = as_int8 >> 6 - llvm::Value* f4_sign = b->CreateLShr(as_int8, 6); - - // Get exponent and mantissa bits without the sign. - // Important: the mask is 0x3F (not 0x7F), discard bit #6. - // f4_bits = as_int8 & 0x3F - llvm::Value* f4_bits = b->CreateAnd(as_int8, i8_const(0x3F)); - - // Convert F16 exponent to F4 exponent by readjusting the exponent bias. - // This produces the "normal" result, i.e. not Inf or NaN or denormal. - // f4_normal = f4_bits - ((f16_bias - f4_bias) << f4_mantissa) - constexpr int f4_exponent_offset = bias_diff << 1; - llvm::Value* f4_normal = b->CreateSub(f4_bits, i8_const(f4_exponent_offset)); - - // If the rounding resulted in zero exponent, the value is incorrect. - // This happens when the input is < 1.0 - // is_underflow = f4_normal <= 1 - llvm::Value* is_underflow = b->CreateICmpSLE(f4_normal, i8_const(1)); - - // Chain of selects that handles the special cases. - // f4_result = - // is_underflow ? (is_one ? 1.0 : (is_zero ? 0.0 : 0.5)) : - // is_overflow ? (is_nan ? -0.0 : 6.0) : - // f4_normal - llvm::Value* f4_result = b->CreateSelect( - is_underflow, - // If underflow, the input is < 1.0; the result is either 0.0, 0.5 or 1.0 - b->CreateSelect(is_one, i8_const(0x2), - b->CreateSelect(is_zero, i8_const(0x0), i8_const(0x1))), - // If overflow, the input is >= 7.0 or infinity or NaN. - b->CreateSelect(is_overflow, - b->CreateSelect(is_nan, i8_const(0x8), i8_const(0x7)), - f4_normal)); - - // Add sign to the resulting value. - // f4_signed_result = (f4_sign << 3) | f4_result - return b->CreateOr(f4_result, b->CreateShl(f4_sign, 3)); -} - -llvm::Value* EmitF4e2m1fnToF16(llvm::Value* f8_value, llvm::IRBuilderBase* b) { - auto i16_const = [&](int val) { - return llvm::ConstantInt::get(b->getInt16Ty(), val); - }; - constexpr int mantissa_diff = 9; // 10 for F16, 1 for F4 - constexpr int bias_diff = 14; // 15 for F16, 1 for F4 - - // The input value is a 8-bit integer, extend it to 16-bit integer. - // as_int16 = bitcast(f8_value, int) - llvm::Value* as_int16 = b->CreateZExt(f8_value, b->getInt16Ty()); - - // Get the sign and shift it to F16 position. - // f4_sign = as_int16 >> 3 - // f16_sign_bit = f4_sign << 15 - llvm::Value* f4_sign = b->CreateLShr(as_int16, 3); - llvm::Value* f16_sign_bit = b->CreateShl(f4_sign, 15); - - // Get exponent and mantissa bits without the sign. - // f4_bits = as_int16 & 0x7 - // f16_bits = f4_bits << (f16_mantissa - f4_mantissa) - llvm::Value* f4_bits = b->CreateAnd(as_int16, i16_const(0x7)); - llvm::Value* f16_bits = b->CreateShl(f4_bits, mantissa_diff); - - // Convert F16 exponent to F4 exponent by readjusting the exponent bias. - // f4_normal = f4_bits - ((f16_bias - f4_bias) << f4_mantissa) - constexpr int f16_exponent_offset = bias_diff << 10; - llvm::Value* f16_normal = - b->CreateAdd(f16_bits, i16_const(f16_exponent_offset)); - - // For denormal and zero, the exponent is different. Handle these cases - // separately below. - // is_denorm_or_zero = f4_bits <= 1 - // is_zero = f4_bits == 0 - llvm::Value* is_denorm_or_zero = b->CreateICmpULE(f4_bits, i16_const(1)); - llvm::Value* is_zero = b->CreateICmpEQ(f4_bits, i16_const(0)); - - // Chain of selects that handles the special cases. - // f16_result = is_denorm_or_zero ? (is_zero ? 0.0 : 0.5) : f16_normal - llvm::Value* f16_result = b->CreateSelect( - is_denorm_or_zero, - b->CreateSelect(is_zero, i16_const(0x0000), i16_const(0x3800)), - f16_normal); - - // Add sign to the resulting value. - // f16_signed_result = f16_sign_bit | f16_result - llvm::Value* f16_signed_result = b->CreateOr(f16_result, f16_sign_bit); - return b->CreateBitCast(f16_signed_result, b->getHalfTy()); -} - -llvm::Value* EmitF32ToF8e8m0fnu(llvm::Value* f32_value, - llvm::IRBuilderBase* b) { - auto i32_const = [&](int val) { - return llvm::ConstantInt::get(b->getInt32Ty(), val); - }; - - // Cast the input value to an integer for bitwise manipulation. - // as_int32 = bitcast(f32_value, int) - llvm::Value* as_int32 = b->CreateBitCast(f32_value, b->getInt32Ty()); - - // Check if the input is zero, negative, overflow, infinity or NaN. - // All of these cases cannot be represented in the E8M0 format. - // is_zero_or_negative = as_int32 <= 0 - // is_overflow_or_nan = as_int32 >= 0x1.8p127 - // is_nan = is_zero_or_negative | is_overflow_or_nan - llvm::Value* is_zero_or_negative = b->CreateICmpSLE(as_int32, i32_const(0)); - llvm::Value* is_overflow_or_nan = - b->CreateICmpSGE(as_int32, i32_const(0x7F400000)); // 1.5 * 2^127 - llvm::Value* is_nan = b->CreateOr(is_zero_or_negative, is_overflow_or_nan); - - // Check if the input is a denormal which should round to the minimum value - // (2^-127), as there is no zero value. - // is_denorm = as_int32 <= 0x1p-127 - llvm::Value* is_denorm = - b->CreateICmpULE(as_int32, i32_const(0x400000)); // 1.0 * 2^-127 - - // Round the value (always up) and discard the mantissa. - // rounded = as_int32 + 0x1p-127 - // f8_normal = as_int32 >> f32_mantissa - llvm::Value* rounded = - b->CreateAdd(as_int32, i32_const(0x400000)); // 1.0 * 2^-127 - llvm::Value* f8_normal = b->CreateAShr(rounded, 23); - - // Chain of selects that handles the special cases. - // f8_result = is_nan ? 0xFF : (is_denorm ? 0x00 : f8_normal) - llvm::Value* f8_result = - b->CreateSelect(is_nan, i32_const(0xFF), - b->CreateSelect(is_denorm, i32_const(0x00), f8_normal)); - - // Truncate to the result type. - return b->CreateTrunc(f8_result, b->getInt8Ty()); -} - -llvm::Value* EmitF8e8m0fnuToF32(llvm::Value* f8_value, llvm::IRBuilderBase* b) { - auto i32_const = [&](int val) { - return llvm::ConstantInt::get(b->getInt32Ty(), val); - }; - - // The input value is a 8-bit integer, extend it to 32-bit integer. - // as_int32 = bitcast(f8_value, int) - llvm::Value* as_int32 = b->CreateZExt(f8_value, b->getInt32Ty()); - - // Check if the input is a denormal or NaN. - // is_zero = as_int32 == 0x00 - // is_nan = as_int32 == 0xFF - llvm::Value* is_zero = b->CreateICmpEQ(as_int32, i32_const(0)); - llvm::Value* is_nan = b->CreateICmpEQ(as_int32, i32_const(0xFF)); - - // Shift exponent to the left for the normal case. - // f32_normal = as_int32 << mantissa_diff - llvm::Value* f32_normal = b->CreateShl(as_int32, 23); - - // Chain of selects that handles the special cases. - // f32_result = is_nan ? 0x7FC00000 : (is_zero ? 0x1p-127 : f32_normal) - llvm::Value* f32_result = b->CreateSelect( - is_nan, i32_const(0x7FC00000), - b->CreateSelect(is_zero, i32_const(0x400000), f32_normal)); - - // Bitcast integer bits to the result type. - return b->CreateBitCast(f32_result, b->getFloatTy()); -} - llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, PrimitiveType from_type, PrimitiveType to_type, llvm::Module* module, @@ -1175,18 +958,6 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), b_); } - if (to_type == F4E2M1FN) { - return EmitF16ToF4e2m1fn( - EmitIntegralToFloating(operand_value, from_type, F16, module_, - b_), - b_); - } - if (to_type == F8E8M0FNU) { - return EmitF32ToF8e8m0fnu( - EmitIntegralToFloating(operand_value, from_type, F32, module_, - b_), - b_); - } if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz( F16, @@ -1392,29 +1163,10 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } - if (from_type == F4E2M1FN) { - TF_RET_CHECK(to_type != F4E2M1FN); - operand_value = EmitF4e2m1fnToF16(operand_value, b_); - from_type = F16; - if (from_type == to_type) { - return operand_value; - } - } - if (from_type == F8E8M0FNU) { - TF_RET_CHECK(to_type != F8E8M0FNU); - operand_value = EmitF8e8m0fnuToF32(operand_value, b_); - from_type = F32; - if (from_type == to_type) { - return operand_value; - } - } if (from_type == F8E5M2FNUZ || from_type == F8E4M3FNUZ) { TF_RET_CHECK(to_type != from_type); PrimitiveType cast_type = primitive_util::IsFloatingPointType(to_type) ? to_type : F16; - if (to_type == F8E8M0FNU || to_type == F4E2M1FN) { - cast_type = F32; - } TF_ASSIGN_OR_RETURN(operand_value, EmitF8fnuzToFloating(from_type, operand_value, cast_type, b_, module_)); @@ -1497,24 +1249,6 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitF16ToF8e4m3b11fnuz(operand_value, b_); } - if (to_type == F4E2M1FN) { - // Cast to F16 first. Casts to F4E2M1FN must be from F16. - if (from_type != F16) { - operand_value = b_->CreateFPCast( - operand_value, - llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext())); - } - return EmitF16ToF4e2m1fn(operand_value, b_); - } - if (to_type == F8E8M0FNU) { - // Cast to F32 first. Casts to F8E8M0FNU must be from F32. - if (from_type != F32) { - operand_value = b_->CreateFPCast( - operand_value, - llvm_ir::PrimitiveTypeToIrType(F32, module_->getContext())); - } - return EmitF32ToF8e8m0fnu(operand_value, b_); - } if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz(from_type, operand_value, to_type, b_); } @@ -2075,12 +1809,6 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( } else if (operand_type == F8E4M3FN) { lhs_value = EmitF8e4m3fnToF16(lhs_value, b_); rhs_value = EmitF8e4m3fnToF16(rhs_value, b_); - } else if (operand_type == F4E2M1FN) { - lhs_value = EmitF4e2m1fnToF16(lhs_value, b_); - rhs_value = EmitF4e2m1fnToF16(rhs_value, b_); - } else if (operand_type == F8E8M0FNU) { - lhs_value = EmitF8e8m0fnuToF32(lhs_value, b_); - rhs_value = EmitF8e8m0fnuToF32(rhs_value, b_); } else if (operand_type == F8E5M2FNUZ || operand_type == F8E4M3FNUZ) { TF_ASSIGN_OR_RETURN( lhs_value, @@ -3935,8 +3663,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( primitive_util::IsFloatingPointType(component_element_type)) << component_element_type; llvm::Type* float_ir_type; - if (component_element_type == F8E4M3FNUZ || - component_element_type == F8E5M2FNUZ) { + if (component_element_type == F8E4M3FNUZ) { + float_ir_type = + llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext()); + } else if (component_element_type == F8E5M2FNUZ) { float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext()); } else { diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index 0d906f47b4c47..f947aa8ada14c 100644 --- a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -100,10 +100,9 @@ class ElementalIrEmitterExecutionTypedTest }; using FloatTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ElementalIrEmitterExecutionTypedTest, FloatTypes); @@ -616,9 +615,7 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { std::is_same() || std::is_same() || std::is_same() || - std::is_same() || - std::is_same() || - std::is_same()) { + std::is_same()) { GTEST_SKIP() << "Skipping test for type " << tname; } const auto hlo_text = absl::StrReplaceAll(R"( @@ -633,10 +630,6 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { TYPED_TEST(ElementalIrEmitterExecutionTypedTest, BatchDotFloat) { auto tname = this->TypeName(); - if (std::is_same() || - std::is_same()) { - GTEST_SKIP() << "Skipping test for type " << tname; - } const auto hlo_text = absl::StrReplaceAll(R"( HloModule matmul diff --git a/xla/service/float8_fnuz_ir_emitter.cc b/xla/service/float8_fnuz_ir_emitter.cc index e0be95da5f668..4afb96362cf86 100644 --- a/xla/service/float8_fnuz_ir_emitter.cc +++ b/xla/service/float8_fnuz_ir_emitter.cc @@ -40,8 +40,6 @@ namespace { absl::StatusOr PrimitiveTypeToAPFloatSemantics( PrimitiveType type) { switch (type) { - case F4E2M1FN: - return &llvm::APFloat::Float4E2M1FN(); case F8E3M4: return &llvm::APFloat::Float8E3M4(); case F8E4M3: @@ -56,8 +54,6 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( return &llvm::APFloat::Float8E5M2(); case F8E5M2FNUZ: return &llvm::APFloat::Float8E5M2FNUZ(); - case F8E8M0FNU: - return &llvm::APFloat::Float8E8M0FNU(); case BF16: return &llvm::APFloat::BFloat(); case F16: @@ -76,8 +72,6 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilderBase* b, PrimitiveType type) { switch (type) { - case F4E2M1FN: - return b->getIntNTy(4); case F8E3M4: case F8E4M3: case F8E4M3B11FNUZ: @@ -85,7 +79,6 @@ absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilderBase* b, case F8E4M3FNUZ: case F8E5M2: case F8E5M2FNUZ: - case F8E8M0FNU: return b->getInt8Ty(); case BF16: return b->getBFloatTy(); @@ -656,14 +649,8 @@ absl::StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, llvm::ConstantInt::get(b->getInt8Ty(), 0x0u), sign); // Bitwise or the sign bit back in. - int shift = output_type_bit_width - BitWidth(input_type); - if (shift >= 0) { - sign = b->CreateZExt(sign, output_int_type); - sign = b->CreateShl(sign, shift); - } else { - sign = b->CreateLShr(sign, -shift); - sign = b->CreateTrunc(sign, output_int_type); - } + sign = b->CreateZExt(sign, output_int_type); + sign = b->CreateShl(sign, output_type_bit_width - BitWidth(input_type)); llvm::Value* result = b->CreateOr(sign, result_abs); // Bitcast to the output type. diff --git a/xla/service/gpu/fusions/triton/triton_support_test.cc b/xla/service/gpu/fusions/triton/triton_support_test.cc index 897bc03d78315..5d0c696ccc980 100644 --- a/xla/service/gpu/fusions/triton/triton_support_test.cc +++ b/xla/service/gpu/fusions/triton/triton_support_test.cc @@ -550,18 +550,9 @@ INSTANTIATE_TEST_SUITE_P( using ReduceTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam; -static absl::string_view init_value(PrimitiveType dtype) { - if (dtype == C64 || dtype == C128) { - return "(0, 0)"; - } else if (dtype == F8E8M0FNU) { - return "1e-40"; - } else { - return "0"; - } -} - TEST_P(ReduceTest, IsTritonSupportedReduction) { auto [data_type, opcode, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( add { @@ -576,7 +567,7 @@ ENTRY triton_computation { ROOT reduce = $0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=add })", - "$0", init_value(data_type)); + "$0", dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -608,6 +599,7 @@ TEST_P( ReduceTest, UnsupportedReduceWithMoreThanOneReduceDimensionsFailsGracefullyWithTriton) { auto [data_type, opcode, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( add { @@ -622,7 +614,7 @@ ENTRY triton_computation { ROOT reduce = $0[2] reduce(parameter_0, constant_0), dimensions={1,2}, to_apply=add })", - "$0", init_value(data_type)); + "$0", dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -632,6 +624,7 @@ ENTRY triton_computation { TEST_P(ReduceTest, IsTritonSupportedReduceWithNonLastReduceDimension) { auto [data_type, opcode, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( add { @@ -645,7 +638,7 @@ ENTRY triton_computation { constant_0 = $0[] constant($1) ROOT reduce = $0[127] reduce(parameter_0, constant_0), dimensions={0}, to_apply=add })", - "$0", init_value(data_type)); + "$0", dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -656,6 +649,7 @@ ENTRY triton_computation { TEST_P(ReduceTest, UnsupportedReduceWithMoreThanOneOperandsFailsGracefullyWithTriton) { auto [data_type, opcode, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( add { @@ -676,7 +670,7 @@ ENTRY triton_computation { dimensions={1}, to_apply=add ROOT reduce = $0[125] get-tuple-element(tuple), index=0 })", - "$0", init_value(data_type)); + "$0", dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -707,6 +701,7 @@ ENTRY triton_computation { TEST_P(ReduceTest, UnsupportedReductionComputationFailsGracefullyWithTriton) { auto [data_type, opcode, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( custom_call { @@ -721,7 +716,7 @@ ENTRY triton_computation { ROOT reduce = $0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=custom_call })", - "$0", init_value(data_type)); + "$0", dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -745,6 +740,7 @@ using ReductionComputationTest = // computation and in regular HLO. See triton_support.cc for more details. TEST_P(ReductionComputationTest, DifferentBinaryOps) { auto [data_type, opcode, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute( R"( reduce_computation { @@ -759,7 +755,7 @@ ENTRY triton_computation { ROOT reduce = $0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=reduce_computation })", - "$0", HloOpcodeString(opcode), init_value(data_type)); + "$0", HloOpcodeString(opcode), dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -1119,12 +1115,13 @@ TEST_P(ConstantTest, ConstantEffectiveScalar) { // The IsTritonSupportedReduction effectively tests the scalar constant // support. auto [data_type, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( ENTRY triton_computation { ROOT const = $0[1,1] constant({{$1}}) })", - "$0", init_value(data_type)); + "$0", dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( kHloTestTemplate, data_type, @@ -1136,12 +1133,13 @@ TEST_P(ConstantTest, Constant2D) { // The IsTritonSupportedReduction effectively tests the scalar constant // support. auto [data_type, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( ENTRY triton_computation { ROOT const = $0[3,3] constant({{$1,$1,$1},{$1,$1,$1},{$1,$1,$1}}) })", - "$0", init_value(data_type)); + "$0", dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( kHloTestTemplate, data_type, diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 666c187998cb6..faeaa7a6c4667 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1478,8 +1478,6 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( const GpuFloatSupport f8e3m4_support(gpu_version, F8E3M4, F16); const GpuFloatSupport s4_support(gpu_version, S4, S8); const GpuFloatSupport u4_support(gpu_version, U4, U8); - const GpuFloatSupport f4e2m1fn_support(gpu_version, F4E2M1FN, F16); - const GpuFloatSupport f8e8m0fnu_support(gpu_version, F8E8M0FNU, F32); auto add_float_normalization = [&](HloPassPipeline& pipeline) { auto& sub_pipeline = pipeline.AddPass("float_normalization"); @@ -1493,8 +1491,6 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( sub_pipeline.AddPass(&f8e3m4_support); sub_pipeline.AddPass(&s4_support); sub_pipeline.AddPass(&u4_support); - sub_pipeline.AddPass(&f4e2m1fn_support); - sub_pipeline.AddPass(&f8e8m0fnu_support); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. if (debug_options.xla_allow_excess_precision()) { sub_pipeline.AddPass(); diff --git a/xla/service/gpu/tests/float_conversions_test.cc b/xla/service/gpu/tests/float_conversions_test.cc index 6e0e14e320a7f..16383324dfb01 100644 --- a/xla/service/gpu/tests/float_conversions_test.cc +++ b/xla/service/gpu/tests/float_conversions_test.cc @@ -29,10 +29,9 @@ class FloatConversionParamTest INSTANTIATE_TEST_SUITE_P(FloatConversionParamSuite, FloatConversionParamTest, ::testing::Values("f64", "f32", "f16", "bf16", - "f4e2m1fn", "f8e5m2", "f8e5m2fnuz", - "f8e4m3", "f8e4m3fn", "f8e4m3fnuz", - "f8e4m3b11fnuz", "f8e3m4", - "f8e8m0fnu")); + "f8e5m2", "f8e5m2fnuz", "f8e4m3", + "f8e4m3fn", "f8e4m3fnuz", + "f8e4m3b11fnuz", "f8e3m4")); TEST_P(FloatConversionParamTest, FloatToF16) { auto type_name = GetParam(); diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc index 38dfd05667e00..88823f1dd9e5c 100644 --- a/xla/service/hlo_verifier.cc +++ b/xla/service/hlo_verifier.cc @@ -2972,10 +2972,9 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { Layout::Equal().IgnoreTiles().IgnoreMemorySpace(); if (instruction->opcode() == HloOpcode::kConvert || instruction->opcode() == HloOpcode::kCompare || - instruction->opcode() == HloOpcode::kIsFinite || (instruction->opcode() == HloOpcode::kSelect && operand_shape.element_type() == PRED)) { - // Some instructions can change element_size_in_bits + // Convert and Compare instructions can change element_size_in_bits // Select instructions ignore element_size_in_bits for predicate equal_predicate.IgnoreElementSize(); } else if (instruction->opcode() == HloOpcode::kDynamicSlice || diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc index b937dbc1500b6..d56172dd4b254 100644 --- a/xla/service/llvm_ir/llvm_util.cc +++ b/xla/service/llvm_ir/llvm_util.cc @@ -199,8 +199,6 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, case S16: case U16: return llvm::Type::getInt16Ty(context); - case F4E2M1FN: - return llvm::Type::getIntNTy(context, 4); case F8E5M2: case F8E5M2FNUZ: case F8E4M3: @@ -208,7 +206,6 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, case F8E4M3B11FNUZ: case F8E4M3FNUZ: case F8E3M4: - case F8E8M0FNU: // We represent F8 as an int since there is no LLVM F8 dtype. return llvm::Type::getInt8Ty(context); case BF16: diff --git a/xla/stream_executor/data_type.h b/xla/stream_executor/data_type.h index e3e7d1f17e312..f5246389e485c 100644 --- a/xla/stream_executor/data_type.h +++ b/xla/stream_executor/data_type.h @@ -37,10 +37,6 @@ struct ToDataType; // Note: If you add a new specialization below, make sure to add the // corresponding definition in stream_executor/dnn.cc. template <> -struct ToDataType { - static constexpr DataType value = DataType::kF4E2M1FN; -}; -template <> struct ToDataType { static constexpr DataType value = DataType::kF8E3M4; }; @@ -65,10 +61,6 @@ struct ToDataType { static constexpr DataType value = DataType::kF8E5M2FNUZ; }; template <> -struct ToDataType { - static constexpr DataType value = DataType::kF8E8M0FNU; -}; -template <> struct ToDataType { static constexpr DataType value = DataType::kFloat; }; diff --git a/xla/stream_executor/dnn.cc b/xla/stream_executor/dnn.cc index 24851e56d75ed..6b7a87d80b3ae 100644 --- a/xla/stream_executor/dnn.cc +++ b/xla/stream_executor/dnn.cc @@ -69,14 +69,12 @@ bool ProtoMapsEqual(const google::protobuf::Map& x, } // namespace -constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; -constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; diff --git a/xla/stream_executor/gpu/gpu_blas_lt.cc b/xla/stream_executor/gpu/gpu_blas_lt.cc index 182af599af9e5..6aee86bf2cbc1 100644 --- a/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -56,10 +56,6 @@ absl::StatusOr AsBlasDataType(PrimitiveType dtype) { return DataType::kF8E4M3FNUZ; case PrimitiveType::F8E3M4: return DataType::kF8E3M4; - case PrimitiveType::F4E2M1FN: - return DataType::kF4E2M1FN; - case PrimitiveType::F8E8M0FNU: - return DataType::kF8E8M0FNU; case PrimitiveType::S8: return DataType::kInt8; case PrimitiveType::F16: @@ -97,10 +93,6 @@ absl::StatusOr AsXlaPrimitiveType(DataType dtype) { return PrimitiveType::F8E4M3FNUZ; case DataType::kF8E3M4: return PrimitiveType::F8E3M4; - case DataType::kF4E2M1FN: - return PrimitiveType::F4E2M1FN; - case DataType::kF8E8M0FNU: - return PrimitiveType::F8E8M0FNU; case DataType::kInt8: return PrimitiveType::S8; case DataType::kHalf: @@ -162,8 +154,6 @@ absl::StatusOr GetBlasComputationType( case PrimitiveType::F8E5M2FNUZ: // fall-through case PrimitiveType::F8E4M3FNUZ: // fall-through case PrimitiveType::F8E3M4: // fall-through - case PrimitiveType::F4E2M1FN: // fall-through - case PrimitiveType::F8E8M0FNU: // fall-through case PrimitiveType::F16: // fall-through case PrimitiveType::BF16: // Accumulate in f32 precision. diff --git a/xla/stream_executor/rocm/hip_blas_utils.cc b/xla/stream_executor/rocm/hip_blas_utils.cc index 8864476bf0d82..e5730121addd8 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/xla/stream_executor/rocm/hip_blas_utils.cc @@ -39,10 +39,8 @@ hipDataType AsHipblasDataType(blas::DataType type) { case blas::DataType::kF8E4M3: case blas::DataType::kF8E4M3FN: case blas::DataType::kF8E3M4: - case blas::DataType::kF4E2M1FN: - case blas::DataType::kF8E8M0FNU: - LOG(FATAL) << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN, " - "F8E3M4, F4E2M1FN and F8E8M0FNU"; + LOG(FATAL) + << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN and F8E3M4"; #if TF_ROCM_VERSION >= 60000 case blas::DataType::kF8E5M2FNUZ: return HIP_R_8F_E5M2_FNUZ; diff --git a/xla/tests/BUILD b/xla/tests/BUILD index db70ca4d0c831..6f82474e3561f 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -863,14 +863,12 @@ xla_test( "//xla:shape_util", "//xla:test", "//xla:types", - "//xla:util", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/base", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", "@ml_dtypes//:float8", "@tsl//tsl/platform:ml_dtypes", ] + if_rocm_is_configured([ diff --git a/xla/tests/array_elementwise_ops_test.cc b/xla/tests/array_elementwise_ops_test.cc index f2fb97be51f68..c12ce79a06e8f 100644 --- a/xla/tests/array_elementwise_ops_test.cc +++ b/xla/tests/array_elementwise_ops_test.cc @@ -27,7 +27,6 @@ limitations under the License. #include #include -#include #include "absl/base/casts.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -48,7 +47,6 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/types.h" -#include "xla/util.h" #include "tsl/platform/ml_dtypes.h" #if TENSORFLOW_USE_ROCM @@ -95,20 +93,6 @@ std::pair, std::vector> AllSignedPairs( return {xs, ys}; } -template -void AddNegativeValuesMaybeRemoveZero(std::vector& values) { - values.reserve(values.size() * 2); - if (!has_zero_v) { - values.erase(values.begin()); - } - for (size_t i = 0, n = values.size(); i < n; ++i) { - auto neg = -values[i]; - if (SignAndMagnitude(neg).first) { - values.push_back(neg); - } - } -} - class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: static constexpr float kEpsF32 = std::numeric_limits::epsilon(); @@ -1387,7 +1371,14 @@ class TotalOrderTest : public ClientLibraryTestBase { values.push_back(Eigen::numext::abs(std::numeric_limits::quiet_NaN())); } #endif - AddNegativeValuesMaybeRemoveZero(values); + values.reserve(values.size() * 2); + for (size_t i = 0, n = values.size(); i < n; ++i) { + auto value = values[i]; + auto neg = -value; + if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { + values.push_back(neg); + } + } std::vector lhs_data; std::vector rhs_data; lhs_data.reserve(values.size() * values.size()); @@ -1432,24 +1423,19 @@ class TotalOrderTest : public ClientLibraryTestBase { } }; -using Types = - ::testing::Types; + float>; TYPED_TEST_SUITE(TotalOrderTest, Types); @@ -1476,7 +1462,13 @@ TYPED_TEST(TotalOrderTest, LargeMagnitudeVsNaN) { if constexpr (std::numeric_limits::has_infinity) { values.push_back(std::numeric_limits::infinity()); } - AddNegativeValuesMaybeRemoveZero(values); + for (size_t i = 0, n = values.size(); i < n; ++i) { + auto value = values[i]; + auto neg = -value; + if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { + values.push_back(neg); + } + } auto lhs = ConstantR1(&builder, values); auto rhs = ConstantR1( &builder, diff --git a/xla/tests/constants_test.cc b/xla/tests/constants_test.cc index 9e191a30b405a..9650077ed57b2 100644 --- a/xla/tests/constants_test.cc +++ b/xla/tests/constants_test.cc @@ -52,13 +52,7 @@ using FloatTypes = ::testing::Types; + tsl::float8_e5m2fnuz>; TYPED_TEST_SUITE(ConstantsFloatTest, FloatTypes); diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index a8e370ad50c0d..4f06ea0cc290c 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -54,17 +54,9 @@ class ConvertTestT : public ConvertTest { using ConvertTest::ConvertTest; }; using FloatingPointTypeList = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ConvertTestT, FloatingPointTypeList); template @@ -749,11 +741,10 @@ XLA_TYPED_TEST(ConvertTestT, ConvertFPToPred) { XlaBuilder builder(this->TestName()); using FP = TypeParam; - auto a = ConstantR1(&builder, {FP{0.0}, FP{0.5}, FP{2.0}, FP{-0.0}}); + auto a = ConstantR1(&builder, {FP{0.0}, FP{0.25}, FP{2.0}, FP{-0.0}}); ConvertElementType(a, PRED); - bool zero_pred = !has_zero_v; - std::array expected = {zero_pred, true, true, zero_pred}; + std::array expected = {false, true, true, false}; this->template ComputeAndCompareR1(&builder, expected, {}); } @@ -1934,283 +1925,5 @@ XLA_TYPED_TEST(ConvertTestF16, ConvertF8e3m4F16RoundtripExhaustive4) { this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -// ----- F4E2M1FN - -XLA_TEST_F(ConvertTest, DISABLED_ON_TPU(ConvertF16F4e2m1fnRoundtrip)) { - // Convert from FP16 to FP4, then back to FP16. - XlaBuilder builder(TestName()); - float inf = std::numeric_limits::infinity(); - - struct TestCase { - float input; - float expected_roundtrip; - } test_cases[] = { - // clang-format off - {0.0, 0.0}, - {-0.0, -0.0}, - {1.0, 1.0}, - {-1.0, -1.0}, - {inf, 0x1.8p2}, - // clang-format on - {0x1.4p0, 0x1p0}, // Round-to-even down - {0x1.Cp0, 0x1p1}, // Round-to-even up - {0x1.8p2, 0x1.8p2}, // Max value - {0x1.BFCp2, 0x1.8p2}, // Largest number that doesn't overflow - {0x1.Cp2, 0x1.8p2}, // Smallest number that overflows - {0x1p3, 0x1.8p2}, // Overflow - {0x1p0, 0x1p0}, // Smallest F8 normal - {0x1.8p-1, 0x1p0}, // Smallest number rounding up to normal - - // Denormal tests - {0x1.0p-1, 0x1.0p-1}, // Denormal without rounding - {0x1.8p-1, 0x1.0p0}, // Round-to-even up - {0x1.6p-1, 0x1.0p-1}, // Round-to-nearest down - {0x1.Ep-1, 0x1.0p0}, // Round-to-nearest up - {0x1p-2, 0}, // Largest number that underflows - {0x1.004p-2, 0x1p-1}, // Smallest number that doesn't underflow - {0x1.7FCp-1, 0x1p-1}, // Largest number that rounds to denormal - }; - - std::vector inputs; - std::vector expected_roundtrip; - for (auto test_case : test_cases) { - inputs.push_back(Eigen::half{test_case.input}); - expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); - } - - auto f4 = - ConvertElementType(ConstantR1(&builder, inputs), F4E2M1FN); - ConvertElementType(f4, F16); - ComputeAndCompareR1(&builder, expected_roundtrip, {}, - ErrorSpec(0.)); -} - -XLA_TEST_F(ConvertTest, - DISABLED_ON_TPU(DISABLED_ON_CPU(ConvertF32F4e2m1fnRoundtrip))) { - // Convert from FP32 to FP4, then back to FP32. - XlaBuilder builder(TestName()); - float inf = std::numeric_limits::infinity(); - - struct TestCase { - float input; - float expected_roundtrip; - } test_cases[] = { - // clang-format off - {0.0, 0.0}, - {-0.0, -0.0}, - {1.0, 1.0}, - {-1.0, -1.0}, - {inf, 0x1.8p2}, - // clang-format on - {0x1.4p0, 0x1p0}, // Round-to-even down - {0x1.Cp0, 0x1p1}, // Round-to-even up - {0x1.8p2, 0x1.8p2}, // Max value - {0x1.BFFFFEp2, 0x1.8p2}, // Largest number that doesn't overflow - {0x1.Cp2, 0x1.8p2}, // Smallest number that overflows - {0x1p3, 0x1.8p2}, // Overflow - {0x1p0, 0x1p0}, // Smallest F8 normal - {0x1.8p-1, 0x1p0}, // Smallest number rounding up to normal - - // Denormal tests - {0x1.0p-1, 0x1.0p-1}, // Denormal without rounding - {0x1.8p-1, 0x1.0p0}, // Round-to-even up - {0x1.6p-1, 0x1.0p-1}, // Round-to-nearest down - {0x1.Ep-1, 0x1.0p0}, // Round-to-nearest up - {0x1p-2, 0}, // Largest number that underflows - {0x1.000002p-2, 0x1p-1}, // Smallest number that doesn't underflow - {0x1.7FFFFEp-1, 0x1p-1}, // Largest number that rounds to denormal - }; - - std::vector inputs; - std::vector expected_roundtrip; - for (auto test_case : test_cases) { - inputs.push_back(test_case.input); - expected_roundtrip.push_back(test_case.expected_roundtrip); - } - - auto f4 = ConvertElementType(ConstantR1(&builder, inputs), F4E2M1FN); - ConvertElementType(f4, F32); - ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); -} - -XLA_TYPED_TEST(ConvertTestT, - DISABLED_ON_TPU(ConvertF4e2m1fnRoundtripExhaustive)) { - // Convert from FP4 to supported floating point type, then back to FP4. - XlaBuilder builder(this->TestName()); - - using From = tsl::float4_e2m1fn; - std::vector all_f4; - for (int i = 0; i < 16; i++) { - all_f4.push_back(Eigen::numext::bit_cast(static_cast(i))); - } - - xla::XlaOp all_f4_as_fp = - ConvertElementType(ConstantR1(&builder, all_f4), - primitive_util::NativeToPrimitiveType()); - ConvertElementType(all_f4_as_fp, F4E2M1FN); - this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); -} - -XLA_TYPED_TEST(ConvertTestT, - DISABLED_ON_TPU(ConvertF4e2m1fnRoundtripExhaustive2)) { - // Convert from supported floating point type to FP4. - XlaBuilder builder(this->TestName()); - - std::vector all_f4; - for (int i = 0; i < 16; i++) { - all_f4.push_back(static_cast( - Eigen::numext::bit_cast(static_cast(i)))); - } - - ConvertElementType(ConstantR1(&builder, all_f4), F4E2M1FN); - this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); -} - -XLA_TYPED_TEST(ConvertTestT, - DISABLED_ON_TPU(ConvertF4e2m1fnRoundtripExhaustive3)) { - // Convert from FP4 to supported floating point type. - XlaBuilder builder(this->TestName()); - - using From = tsl::float4_e2m1fn; - std::vector all_f4; - for (int i = 0; i < 16; i++) { - all_f4.push_back(Eigen::numext::bit_cast(static_cast(i))); - } - - ConvertElementType(ConstantR1(&builder, all_f4), - primitive_util::NativeToPrimitiveType()); - this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); -} - -XLA_TYPED_TEST(ConvertTestF16, - DISABLED_ON_TPU(ConvertF4e2m1fnF16RoundtripExhaustive4)) { - // Convert from (B)F16 to FP4. - XlaBuilder builder(this->TestName()); - - std::vector all_f16; - for (int i = 0; i < 65536; i++) { - all_f16.push_back( - Eigen::numext::bit_cast(static_cast(i))); - } - - ConvertElementType(ConstantR1(&builder, all_f16), F4E2M1FN); - this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); -} - -// ----- F8E8M0FNU - -XLA_TEST_F(ConvertTest, DISABLED_ON_TPU(ConvertF32F8e8m0fnuRoundtrip)) { - // Convert from FP32 to FP8, then back to FP32. - XlaBuilder builder(TestName()); - float nan = std::numeric_limits::quiet_NaN(); - float inf = std::numeric_limits::infinity(); - - struct TestCase { - float input; - float expected_roundtrip; - } test_cases[] = { - // clang-format off - {0.0, nan}, // No zero values - {-0.0, nan}, - {1.0, 1.0}, - {-1.0, nan}, // No negative values - {nan, nan}, - {inf, nan}, - // clang-format on - {0x1.8p1, 0x1p2}, // Round-to-even up - {0x1.8p2, 0x1p3}, // Round-to-even up (always rounds up) - {0x1p127, 0x1p127}, // Max value - {0x1.7FFFFEp127, 0x1p127}, // Largest number that doesn't overflow - {0x1.8p127, nan}, // Smallest number that overflows - {0x1.FFFFFEp127, nan}, // Overflow - {0x1p-126, 0x1p-126}, // Smallest F8 normal - {0x0.800002p-126, 0x1p-126}, // Smallest number rounding up to normal - }; - - std::vector inputs; - std::vector expected_roundtrip; - for (auto test_case : test_cases) { - inputs.push_back(test_case.input); - expected_roundtrip.push_back(test_case.expected_roundtrip); - } - - auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E8M0FNU); - ConvertElementType(f8, F32); - ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); -} - -XLA_TYPED_TEST(ConvertTestT, - DISABLED_ON_TPU(ConvertF8e8m0fnuRoundtripExhaustive)) { - // Convert from FP8 to supported floating point type, then back to FP8. - XlaBuilder builder(this->TestName()); - - using From = tsl::float8_e8m0fnu; - std::vector all_f8; - for (int i = 0; i < 256; i++) { - all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); - } - - xla::XlaOp all_f8_as_fp = - ConvertElementType(ConstantR1(&builder, all_f8), - primitive_util::NativeToPrimitiveType()); - ConvertElementType(all_f8_as_fp, F8E8M0FNU); - this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); -} - -XLA_TYPED_TEST(ConvertTestT, - DISABLED_ON_TPU(ConvertF8e8m0fnuRoundtripExhaustive2)) { - if (this->client_->platform()->Name() == "Host") { - // This test is disabled on CPU, as converting 0x1p-127 from double to float - // using CVTSD2SS on x64 results in an underflow (even though the result is - // representable as denormalized float32). - if (std::is_same_v) { - GTEST_SKIP() << "Skipping test for double precision floating point that " - "loses denormal value during conversion"; - } - } - // Convert from supported floating point type to FP8. - XlaBuilder builder(this->TestName()); - - std::vector all_f8; - for (int i = 0; i < 256; i++) { - all_f8.push_back(static_cast( - Eigen::numext::bit_cast(static_cast(i)))); - } - - ConvertElementType(ConstantR1(&builder, all_f8), F8E8M0FNU); - this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); -} - -XLA_TYPED_TEST(ConvertTestT, - DISABLED_ON_TPU(ConvertF8e8m0fnuRoundtripExhaustive3)) { - // Convert from FP8 to supported floating point type. - XlaBuilder builder(this->TestName()); - - using From = tsl::float8_e8m0fnu; - std::vector all_f8; - for (int i = 0; i < 256; i++) { - all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); - } - - ConvertElementType(ConstantR1(&builder, all_f8), - primitive_util::NativeToPrimitiveType()); - this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); -} - -XLA_TYPED_TEST(ConvertTestF16, - DISABLED_ON_TPU(ConvertF8e8m0fnuF16RoundtripExhaustive4)) { - // Convert from (B)F16 to FP8. - XlaBuilder builder(this->TestName()); - - std::vector all_f16; - for (int i = 0; i < 65536; i++) { - all_f16.push_back( - Eigen::numext::bit_cast(static_cast(i))); - } - - ConvertElementType(ConstantR1(&builder, all_f16), F8E8M0FNU); - this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); -} - } // namespace } // namespace xla diff --git a/xla/tools/driver.cc b/xla/tools/driver.cc index d1d6882b6532a..7f0d9c4507a2a 100644 --- a/xla/tools/driver.cc +++ b/xla/tools/driver.cc @@ -121,7 +121,6 @@ enum PrimitiveType { F64, C64, C128, - F4E2M1FN, F8E5M2, F8E4M3, F8E4M3FN, @@ -129,19 +128,17 @@ enum PrimitiveType { F8E5M2FNUZ, F8E4M3FNUZ, F8E3M4, - F8E8M0FNU, }; const std::vector& primitive_strings() { static auto vec = new std::vector( - {"s1", "s2", "s4", "s8", - "s16", "s32", "s64", "u1", - "u2", "u4", "u8", "u16", - "u32", "u64", "f16", "bf16", - "f32", "f64", "c64", "c128", - "f4e2m1fn", "f8e3m4", "f8e4m3", "f8e4m3b11fnuz", - "f8e4m3fn", "f8e4m3fnuz", "f8e5m2", "f8e5m2fnuz", - "f8e8m0fnu"}); + {"s1", "s2", "s4", "s8", + "s16", "s32", "s64", "u1", + "u2", "u4", "u8", "u16", + "u32", "u64", "f16", "bf16", + "f32", "f64", "c64", "c128", + "f8e5m2", "f8e4m3", "f8e4m3fn", "f8e4m3b11fnuz", + "f8e5m2fnuz", "f8e4m3fnuz", "f8e3m4"}); return *vec; } @@ -418,7 +415,6 @@ void Fill(void* buffer, const ArrayShape& shape) { case F64: return FillFloatT(buffer, num_elements); - case F4E2M1FN: case F8E5M2: case F8E4M3: case F8E4M3FN: @@ -426,7 +422,6 @@ void Fill(void* buffer, const ArrayShape& shape) { case F8E5M2FNUZ: case F8E4M3FNUZ: case F8E3M4: - case F8E8M0FNU: case F16: case BF16: case C64: @@ -480,7 +475,6 @@ void Display(const void* buffer, const ArrayShape& shape) { case F64: return DisplayT(buffer, num_elements); - case F4E2M1FN: case F8E5M2: case F8E4M3: case F8E4M3FN: @@ -488,7 +482,6 @@ void Display(const void* buffer, const ArrayShape& shape) { case F8E5M2FNUZ: case F8E4M3FNUZ: case F8E3M4: - case F8E8M0FNU: case F16: case BF16: case C64: diff --git a/xla/tsl/framework/BUILD b/xla/tsl/framework/BUILD index 9185021d40550..9f9a030d868e7 100644 --- a/xla/tsl/framework/BUILD +++ b/xla/tsl/framework/BUILD @@ -339,7 +339,6 @@ cc_library( ]), deps = [ ":numeric_types", - "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:types", ], ) diff --git a/xla/tsl/framework/type_traits.h b/xla/tsl/framework/type_traits.h index 2292ee563db80..f7a9bd7a54bc9 100644 --- a/xla/tsl/framework/type_traits.h +++ b/xla/tsl/framework/type_traits.h @@ -21,7 +21,6 @@ limitations under the License. #include #include "xla/tsl/framework/numeric_types.h" -#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/types.h" namespace tsl { @@ -71,15 +70,13 @@ struct is_simple_type { std::is_trivial::value || std::is_same::value || std::is_same::value || std::is_same::value || is_quantized::value || std::is_same::value || - std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value || - std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value; }; diff --git a/xla/tsl/protobuf/dnn.proto b/xla/tsl/protobuf/dnn.proto index 4a6d8fff6f72c..2ac31005c1662 100644 --- a/xla/tsl/protobuf/dnn.proto +++ b/xla/tsl/protobuf/dnn.proto @@ -24,8 +24,6 @@ enum DataType { kInt64 = 12; kF8E4M3 = 13; kF8E3M4 = 14; - kF4E2M1FN = 15; - kF8E8M0FNU = 16; } // Describes how a convolution input or output layer's data is formatted. diff --git a/xla/tsl/python/lib/core/ml_dtypes.cc b/xla/tsl/python/lib/core/ml_dtypes.cc index a986efb7cca96..e2c5eb295c6b1 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.cc +++ b/xla/tsl/python/lib/core/ml_dtypes.cc @@ -61,8 +61,6 @@ struct MlDtypesInitInfo { numpy_dtypes.bfloat16 = py::dtype::from_args(ml_dtypes.attr("bfloat16")).num(); - numpy_dtypes.float4_e2m1fn = - py::dtype::from_args(ml_dtypes.attr("float4_e2m1fn")).num(); numpy_dtypes.float8_e3m4 = py::dtype::from_args(ml_dtypes.attr("float8_e3m4")).num(); numpy_dtypes.float8_e4m3 = @@ -77,8 +75,6 @@ struct MlDtypesInitInfo { py::dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")).num(); numpy_dtypes.float8_e5m2fnuz = py::dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")).num(); - numpy_dtypes.float8_e8m0fnu = - py::dtype::from_args(ml_dtypes.attr("float8_e8m0fnu")).num(); numpy_dtypes.int4 = py::dtype::from_args(ml_dtypes.attr("int4")).num(); numpy_dtypes.uint4 = py::dtype::from_args(ml_dtypes.attr("uint4")).num(); } catch (const std::exception& e) { @@ -89,7 +85,6 @@ struct MlDtypesInitInfo { // Verify all types were successfully loaded. if (numpy_dtypes.bfloat16 == NPY_NOTYPE || - numpy_dtypes.float4_e2m1fn == NPY_NOTYPE || numpy_dtypes.float8_e3m4 == NPY_NOTYPE || numpy_dtypes.float8_e4m3 == NPY_NOTYPE || numpy_dtypes.float8_e4m3fn == NPY_NOTYPE || @@ -97,7 +92,6 @@ struct MlDtypesInitInfo { numpy_dtypes.float8_e4m3b11fnuz == NPY_NOTYPE || numpy_dtypes.float8_e5m2 == NPY_NOTYPE || numpy_dtypes.float8_e5m2fnuz == NPY_NOTYPE || - numpy_dtypes.float8_e8m0fnu == NPY_NOTYPE || numpy_dtypes.int4 == NPY_NOTYPE || numpy_dtypes.uint4 == NPY_NOTYPE) { init_valid = false; } diff --git a/xla/tsl/python/lib/core/ml_dtypes.h b/xla/tsl/python/lib/core/ml_dtypes.h index 725d844c27bb4..b3aa94e430239 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.h +++ b/xla/tsl/python/lib/core/ml_dtypes.h @@ -24,7 +24,6 @@ namespace ml_dtypes { struct NumpyDtypes { int bfloat16; - int float4_e2m1fn; int float8_e3m4; int float8_e4m3; int float8_e4m3fn; @@ -32,7 +31,6 @@ struct NumpyDtypes { int float8_e4m3fnuz; int float8_e5m2; int float8_e5m2fnuz; - int float8_e8m0fnu; int int4; int uint4; }; diff --git a/xla/types.h b/xla/types.h index b702404601dae..98e3d7c9331ff 100644 --- a/xla/types.h +++ b/xla/types.h @@ -131,32 +131,16 @@ struct make_specialized_signed>> { template using make_specialized_signed_t = typename make_specialized_signed::type; -// has_negative_zero[_v] - template struct has_negative_zero : std::bool_constant::is_iec559> {}; -template <> -struct has_negative_zero : std::bool_constant {}; - template <> struct has_negative_zero : std::bool_constant {}; template inline constexpr bool has_negative_zero_v = has_negative_zero::value; -// has_zero[_v] - -template -struct has_zero : std::bool_constant {}; - -template <> -struct has_zero : std::bool_constant {}; - -template -inline constexpr bool has_zero_v = has_zero::value; - } // namespace xla #endif // XLA_TYPES_H_ diff --git a/xla/util.cc b/xla/util.cc index 023e09342f113..c18435a04c64b 100644 --- a/xla/util.cc +++ b/xla/util.cc @@ -148,7 +148,6 @@ std::string Reindent(absl::string_view original, template static void RoundTripNanPayload(FloatT value, std::string* result) { - static_assert(std::numeric_limits::has_quiet_NaN); static_assert(!std::is_same::value, "RoundTripNanPayload does not support E4M3FN"); static_assert(!std::is_same::value, @@ -175,10 +174,6 @@ static std::string GenericRoundTripFpToString(FloatT value) { static_cast(value)); } -std::string RoundTripFpToString(tsl::float4_e2m1fn value) { - return GenericRoundTripFpToString(value); -} - std::string RoundTripFpToString(tsl::float8_e5m2 value) { std::string result = GenericRoundTripFpToString(value); RoundTripNanPayload(value, &result); @@ -217,11 +212,6 @@ std::string RoundTripFpToString(tsl::float8_e3m4 value) { return result; } -std::string RoundTripFpToString(tsl::float8_e8m0fnu value) { - std::string result = GenericRoundTripFpToString(value); - return result; -} - std::string RoundTripFpToString(bfloat16 value) { std::string result = GenericRoundTripFpToString(value); RoundTripNanPayload(value, &result); diff --git a/xla/util.h b/xla/util.h index a457870939244..959009073e96f 100644 --- a/xla/util.h +++ b/xla/util.h @@ -416,9 +416,6 @@ std::string VectorString(const std::initializer_list& c) { return VectorString>(c); } -// Returns a string which can losslessly round trip to a float4 E2M1FN. -std::string RoundTripFpToString(tsl::float4_e2m1fn value); - // Returns a string which can losslessly round trip to a float8 E5M2. std::string RoundTripFpToString(tsl::float8_e5m2 value); @@ -440,9 +437,6 @@ std::string RoundTripFpToString(tsl::float8_e4m3fnuz value); // Returns a string which can losslessly round trip to a float8 E3M4. std::string RoundTripFpToString(tsl::float8_e3m4 value); -// Returns a string which can losslessly round trip to a float8 E8M0FNU. -std::string RoundTripFpToString(tsl::float8_e8m0fnu value); - // Returns a string which can losslessly round trip to a bfloat. std::string RoundTripFpToString(tsl::bfloat16 value); @@ -658,9 +652,8 @@ template auto SignAndMagnitude(T x) { using BitType = UnsignedIntegerTypeForSizeType; BitType x_abs_bits = Eigen::numext::bit_cast(Eigen::numext::abs(x)); - // Eigen implements the sign value to be either all-zeros (for positive input) - // or all-ones (for negative input). - BitType x_sign = Eigen::numext::bit_cast(Eigen::numext::signbit(x)); + const BitType x_bits = Eigen::numext::bit_cast(x); + const BitType x_sign = x_bits ^ x_abs_bits; if constexpr (!has_negative_zero_v) { // f8e4m3b11, f8e4m3fnuz, and f8e5m2fnuz don't support -0, adjust negative // numbers to fill in the gap. @@ -671,17 +664,12 @@ auto SignAndMagnitude(T x) { return std::make_pair(x_sign, x_abs_bits); } -template <> -inline auto SignAndMagnitude(tsl::float8_e8m0fnu x) { - uint8_t x_bits = Eigen::numext::bit_cast(x); - return std::make_pair(static_cast(0), x_bits); -} - template auto SignAndMagnitudeToTwosComplement(T sign, T magnitude) { static_assert(!std::numeric_limits::is_signed); using SignedType = std::make_signed_t; - return static_cast(magnitude) ^ static_cast(sign); + return static_cast(magnitude) ^ + (static_cast(sign) < 0 ? SignedType{-1} : SignedType{0}); } // Returns the signed magnitude of T. @@ -691,11 +679,6 @@ auto ToSignMagnitude(T input) { return SignAndMagnitudeToTwosComplement(sign, magnitude); } -template <> -inline auto ToSignMagnitude(tsl::float8_e8m0fnu input) { - return Eigen::numext::bit_cast(input); -} - template constexpr int NanPayloadBits() { // Floating point types with signaling NaNs have payloads. diff --git a/xla/util_test.cc b/xla/util_test.cc index f864b3215aa4a..cc2465099c1d9 100644 --- a/xla/util_test.cc +++ b/xla/util_test.cc @@ -206,9 +206,9 @@ namespace { template void TotalOrderHelper(T x, T y) { auto x_sm = ToSignMagnitude(x); + bool x_sign = static_cast(Eigen::numext::signbit(x)); + bool y_sign = static_cast(Eigen::numext::signbit(y)); auto y_sm = ToSignMagnitude(y); - bool x_sign = static_cast(SignAndMagnitude(x).first); - bool y_sign = static_cast(SignAndMagnitude(y).first); if (x_sign && !y_sign) { EXPECT_LT(x_sm, y_sm) << x << " " << y; } @@ -239,18 +239,6 @@ void TotalOrderHelper(T x, T y) { } } // namespace -TEST(UtilTest, TotalOrder_F4E2M1FN) { - for (int a = 0; a < 16; ++a) { - tsl::float4_e2m1fn x = - Eigen::numext::bit_cast(static_cast(a)); - for (int b = 0; b < 16; ++b) { - tsl::float4_e2m1fn y = - Eigen::numext::bit_cast(static_cast(b)); - TotalOrderHelper(x, y); - } - } -} - TEST(UtilTest, TotalOrder_F8E5M2) { for (int a = 0; a < 256; ++a) { tsl::float8_e5m2 x = @@ -337,18 +325,6 @@ TEST(UtilTest, TotalOrder_F8E3M4) { } } -TEST(UtilTest, TotalOrder_F8E8M0FNU) { - for (int a = 0; a < 256; ++a) { - tsl::float8_e8m0fnu x = - Eigen::numext::bit_cast(static_cast(a)); - for (int b = 0; b < 256; ++b) { - tsl::float8_e8m0fnu y = - Eigen::numext::bit_cast(static_cast(b)); - TotalOrderHelper(x, y); - } - } -} - void PackInt4(absl::Span input, absl::Span output) { CHECK_EQ(output.size(), CeilOfRatio(input.size(), size_t{2})); for (size_t i = 0; i < input.size(); ++i) { diff --git a/xla/xla_data.proto b/xla/xla_data.proto index 87a4b3b35c049..82b822f2e3ecb 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -111,17 +111,6 @@ enum PrimitiveType { F8E5M2FNUZ = 24; F8E4M3FNUZ = 25; - // MX float dtypes, as described in: - // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf - // - // F4E2M1FN has 2 exponent bits and 1 mantissa bit. - // F8E8M0FNU has 8 exponent bits, no mantissa and no sign. - // - // Only finite values are supported (hence "FN" suffix). Unlike IEEE types, - // infinities and NaNs are not supported. - F4E2M1FN = 32; - F8E8M0FNU = 33; - // Complex values of fixed width. C64 = 15; // Paired F32 (real, imag), as in std::complex. C128 = 18; // Paired F64 (real, imag), as in std::complex. @@ -147,7 +136,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 34 + // Next = 32 } // LINT.ThenChange( // https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc @@ -592,17 +581,15 @@ message LiteralProto { bytes bf16s = 13; bytes u16s = 16; bytes s16s = 17; - bytes f4e2m1fns = 30; - bytes f8e3m4s = 29; - bytes f8e4m3b11fnuzs = 23; - bytes f8e4m3fns = 20; - bytes f8e4m3fnuzs = 25; + bytes f8e5m2s = 19; bytes f8e4m3s = 28; + bytes f8e4m3fns = 20; + bytes f8e4m3b11fnuzs = 23; bytes f8e5m2fnuzs = 24; - bytes f8e5m2s = 19; - bytes f8e8m0fnus = 31; + bytes f8e4m3fnuzs = 25; + bytes f8e3m4s = 29; repeated int64 sparse_indices = 14; - // Next = 32 + // Next = 30 } message WindowDimension {