diff --git a/xla/hlo/builder/lib/math_test.cc b/xla/hlo/builder/lib/math_test.cc index 68efbe71c662d..1cdf6648b52f9 100644 --- a/xla/hlo/builder/lib/math_test.cc +++ b/xla/hlo/builder/lib/math_test.cc @@ -96,12 +96,12 @@ class MathTypedTest : public MathTest { bool has_inf = std::numeric_limits::has_infinity; bool has_nan = std::numeric_limits::has_quiet_NaN; - bool is_finite = !has_inf && !has_nan; - bool is_nan_only = !has_inf && has_nan; + bool has_finite = !has_inf && !has_nan; + bool has_nan_only = !has_inf && has_nan; auto expected = LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateR1({true, true, true, true, true, is_finite, - is_finite, is_finite, is_finite}), + LiteralUtil::CreateR1({true, true, true, true, true, has_finite, + has_finite, has_finite, has_finite}), LiteralUtil::CreateR1({false, false, false, false, false, has_inf, has_inf, false, false}), LiteralUtil::CreateR1( @@ -109,7 +109,7 @@ class MathTypedTest : public MathTest { LiteralUtil::CreateR1( {false, false, false, false, false, false, has_inf, false, false}), LiteralUtil::CreateR1({false, false, false, false, false, - is_nan_only, is_nan_only, has_nan, + has_nan_only, has_nan_only, has_nan, has_nan})); ComputeAndCompareLiteral(&b, expected, {}); } diff --git a/xla/literal.h b/xla/literal.h index cf544aa3a0c0c..adee914c09210 100644 --- a/xla/literal.h +++ b/xla/literal.h @@ -593,13 +593,6 @@ class LiteralBase { static_assert(8 % bits_per_element == 0); constexpr int elements_per_byte = 8 / bits_per_element; - constexpr auto cast = [](NativeT x) -> uint8_t { - if constexpr (primitive_util::IsFloatingPointType(primitive_type)) { - return Eigen::numext::bit_cast(x); - } - return static_cast(x); - }; - int64_t bytes = elements.size() / elements_per_byte; for (int64_t i = 0; i < bytes; ++i) { uint8_t byte = 0; @@ -710,7 +703,6 @@ class LiteralBase { static_assert(!primitive_util::IsComplexType(primitive_type)); static_assert(8 % bits_per_element == 0); - constexpr int elements_per_byte = 8 / 
bits_per_element; constexpr auto cast = [](uint8_t x) -> NativeT { if constexpr (primitive_util::IsFloatingPointType(primitive_type)) { return Eigen::numext::bit_cast(x); @@ -718,6 +710,7 @@ class LiteralBase { return static_cast(x); }; + constexpr int elements_per_byte = 8 / bits_per_element; int64_t bytes = elements.size() / elements_per_byte; for (int64_t i = 0; i < bytes; ++i) { uint8_t byte; diff --git a/xla/literal_comparison_test.cc b/xla/literal_comparison_test.cc index ee8aafda2c350..29c12eb7c75e4 100644 --- a/xla/literal_comparison_test.cc +++ b/xla/literal_comparison_test.cc @@ -73,7 +73,7 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); float expV = 1.5; // F8E4M3* if (type == F8E5M2 || type == F8E5M2FNUZ) - expV = 1.75; + expV = 2.0; else if (type == F8E3M4) expV = 1.25; else if (type == F4E2M1FN) @@ -99,7 +99,7 @@ TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) { auto actual = LiteralUtil::CreateR0(1.0); float expV = 1.51; // F8E4M3* if (type == F8E5M2 || type == F8E5M2FNUZ) - expV = 1.76; + expV = 2.01; else if (type == F8E3M4) expV = 1.26; else if (type == F4E2M1FN) diff --git a/xla/python/ifrt/dtype.proto b/xla/python/ifrt/dtype.proto index 8440ce792b93a..2cf453f26c291 100644 --- a/xla/python/ifrt/dtype.proto +++ b/xla/python/ifrt/dtype.proto @@ -81,7 +81,7 @@ message DTypeProto { // collision. 
KIND_STRING = 99; - // Next: 31 + // Next: 32 } // LINT.ThenChange() Kind kind = 1; diff --git a/xla/python/ifrt/dtype_test.cc b/xla/python/ifrt/dtype_test.cc index 55051f0b29cfc..9d3d3105f54e5 100644 --- a/xla/python/ifrt/dtype_test.cc +++ b/xla/python/ifrt/dtype_test.cc @@ -74,7 +74,7 @@ TEST(DTypeTest, BitSize) { {DType::kF8E3M4, 8}, {DType::kF8E4M3, 8}, {DType::kF8E4M3FN, 8}, {DType::kF8E4M3B11FNUZ, 8}, {DType::kF8E4M3FNUZ, 8}, {DType::kF8E5M2, 8}, - {DType::kF8E5M2FNUZ, 8}, {DType::kF8E8M0FNU, 4}, + {DType::kF8E5M2FNUZ, 8}, {DType::kF8E8M0FNU, 8}, {DType::kS16, 16}, {DType::kU16, 16}, {DType::kF16, 16}, {DType::kBF16, 16}, {DType::kS32, 32}, {DType::kU32, 32}, diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index 710d118b8bc8c..18a9228f44316 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -795,6 +795,43 @@ llvm::Value* EmitF8e4m3b11fnuzToF16(llvm::Value* f8_value, absl::StatusOr EmitF16ToF4e2m1fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { + auto i8_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt8Ty(), val); + }; + auto i16_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt16Ty(), val); + }; + constexpr int mantissa_diff = 9; // 10 for F16, 1 for F4 + constexpr int bias_diff = 14; // 15 for F16, 1 for F4 + + // Cast the input value to an integer for bitwise manipulation. + // Get the absolute value of the input (discard the sign). + // f16_bits = bitcast(f16_value, int) + // f16_abs_bits = f16_bits & 0x7FFF + llvm::Value* f16_bits = b->CreateBitCast(f16_value, b->getInt16Ty()); + llvm::Value* f16_abs_bits = b->CreateAnd(f16_bits, i16_const(0x7FFF)); + + // If the input absolute value is >= 7.0 or an infinity, the result saturates + // to max value (6.0). If (0.75 <= input < 1), the result is rounded to 1.0. + // If (0 <= input <= 0.25), the result is rounded to 0.0. 
+ // If the input is NaN, the result is undefined (implemented as minus zero). + // The rest of the cases are handled by the "happy path". + // is_overflow = f16_abs_bits >= 0x1.Cp2 + // is_one = f16_abs_bits >= 0x1.8p-1 (used only if exponent underflows) + // is_zero = f16_abs_bits <= 0x1p-2 (used only if exponent underflows) + // is_nan = f16_abs_bits > 0x7C00 (F16 NaN threshold) + llvm::Value* is_overflow = + b->CreateICmpUGE(f16_abs_bits, i16_const(0x4700)); // 7.0 + llvm::Value* is_one = + b->CreateICmpUGE(f16_abs_bits, i16_const(0x3A00)); // 0.75 + llvm::Value* is_zero = + b->CreateICmpULE(f16_abs_bits, i16_const(0x3400)); // 0.25 + llvm::Value* is_nan = + b->CreateICmpUGT(f16_abs_bits, i16_const(0x7C00)); // inf + + // Truncate the mantissa to 1 bit and the exponent to 3 bits (not 2 bits, as + // the type doesn't have Inf/NaN and can represent unbiased exponent 2). + // This case, as well as the denormal, is handled below. TF_ASSIGN_OR_RETURN( llvm::Value * reduced_precision, EmitReducePrecisionIR( @@ -802,111 +839,173 @@ absl::StatusOr EmitF16ToF4e2m1fn(llvm::Value* f16_value, /*dest_exponent_bits=*/primitive_util::ExponentWidth(F4E2M1FN) + 1, /*dest_mantissa_bits=*/primitive_util::SignificandWidth(F4E2M1FN) - 1, /*quiet_nans=*/false, b)); + + // Cast the reduced precision value to an integer for bitwise manipulation. + // Discard the least significant (9) mantissa bits leaving 1 bit. + // Truncate the shifted value to an 8-bit integer. + // as_int16 = bitcast(reduced_precision, int) + // as_int8 = as_int16 >> (f16_mantissa - f4_mantissa) llvm::Value* as_int16 = b->CreateBitCast(reduced_precision, b->getInt16Ty()); llvm::Value* as_int8 = - b->CreateTrunc(b->CreateLShr(as_int16, 9), b->getInt8Ty()); + b->CreateTrunc(b->CreateLShr(as_int16, mantissa_diff), b->getInt8Ty()); - // Extract sign, exponent and mantissa from reduced precision value. - auto i8_const = [&](int val) { - return llvm::ConstantInt::get(b->getInt8Ty(), val); - }; + // Get the sign (0 or 1). 
+ // f4_sign = as_int8 >> 6 llvm::Value* f4_sign = b->CreateLShr(as_int8, 6); + + // Get exponent and mantissa bits without the sign. + // Important: the mask is 0x3F (not 0x7F), discard bit #6. + // f4_bits = as_int8 & 0x3F llvm::Value* f4_bits = b->CreateAnd(as_int8, i8_const(0x3F)); - llvm::Value* f4_normal = b->CreateSub(f4_bits, i8_const(28)); - // Special case for exponent overflow. - auto i16_const = [&](int val) { - return llvm::ConstantInt::get(b->getInt16Ty(), val); - }; - llvm::Value* f16_bits = b->CreateAnd( - b->CreateBitCast(f16_value, b->getInt16Ty()), i16_const(0x7FFF)); - llvm::Value* is_overflow = - b->CreateICmpUGE(f16_bits, i16_const(0x4700)); // 7.0 - llvm::Value* is_nan = b->CreateICmpUGT(f16_bits, i16_const(0x7C00)); // inf - llvm::Value* max_or_nan = - b->CreateSelect(is_nan, i8_const(0x8), i8_const(0x7)); - llvm::Value* f4_normal_or_overflow = - b->CreateSelect(is_overflow, max_or_nan, f4_normal); - - // Special case for exponent underflow. + // Convert F16 exponent to F4 exponent by readjusting the exponent bias. + // This produces the "normal" result, i.e. not Inf or NaN or denormal. + // f4_normal = f4_bits - ((f16_bias - f4_bias) << f4_mantissa) + constexpr int f4_exponent_offset = bias_diff << 1; + llvm::Value* f4_normal = b->CreateSub(f4_bits, i8_const(f4_exponent_offset)); + + // If the rounding resulted in zero exponent, the value is incorrect. 
+ // This happens when the input is < 1.0 + // is_underflow = f4_normal <= 1 llvm::Value* is_underflow = b->CreateICmpSLE(f4_normal, i8_const(1)); - llvm::Value* is_one = b->CreateICmpUGE(f16_bits, i16_const(0x3A00)); // 0.75 - llvm::Value* is_zero = b->CreateICmpULE(f16_bits, i16_const(0x3400)); // 0.25 - llvm::Value* denorm_or_zero = - b->CreateSelect(is_zero, i8_const(0x0), i8_const(0x1)); - llvm::Value* f4_small = - b->CreateSelect(is_one, i8_const(0x2), denorm_or_zero); - llvm::Value* f4_result = - b->CreateSelect(is_underflow, f4_small, f4_normal_or_overflow); + + // Chain of selects that handles the special cases. + // f4_result = + // is_underflow ? (is_one ? 1.0 : (is_zero ? 0.0 : 0.5)) : + // is_overflow ? (is_nan ? -0.0 : 6.0) : + // f4_normal + llvm::Value* f4_result = b->CreateSelect( + is_underflow, + // If underflow, the input is < 1.0; the result is either 0.0, 0.5 or 1.0 + b->CreateSelect(is_one, i8_const(0x2), + b->CreateSelect(is_zero, i8_const(0x0), i8_const(0x1))), + // If overflow, the input is >= 7.0 or infinity or NaN. + b->CreateSelect(is_overflow, + b->CreateSelect(is_nan, i8_const(0x8), i8_const(0x7)), + f4_normal)); // Add sign to the resulting value. + // f4_signed_result = (f4_sign << 3) | f4_result return b->CreateOr(f4_result, b->CreateShl(f4_sign, 3)); } llvm::Value* EmitF4e2m1fnToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { - llvm::Value* as_int16 = b->CreateZExt(f8_value, b->getInt16Ty()); - - // Extract sign, exponent and mantissa from reduced precision value. auto i16_const = [&](int val) { return llvm::ConstantInt::get(b->getInt16Ty(), val); }; - llvm::Value* sign = b->CreateLShr(as_int16, 3); - llvm::Value* sign_shifted = b->CreateShl(sign, 15); - llvm::Value* bits = b->CreateAnd(as_int16, i16_const(0x7)); - llvm::Value* bits_shifted = b->CreateShl(bits, 9); - - // Re-bias the exponent and handle denormals. 
- llvm::Value* f16_normal = b->CreateAdd(bits_shifted, i16_const(14 << 10)); - llvm::Value* is_denorm_or_zero = b->CreateICmpULE(bits, i16_const(1)); - llvm::Value* is_zero = b->CreateICmpEQ(bits, i16_const(0)); - llvm::Value* denorm_or_zero = - b->CreateSelect(is_zero, i16_const(0x0000), i16_const(0x3800)); - llvm::Value* f16_result = - b->CreateSelect(is_denorm_or_zero, denorm_or_zero, f16_normal); + constexpr int mantissa_diff = 9; // 10 for F16, 1 for F4 + constexpr int bias_diff = 14; // 15 for F16, 1 for F4 + + // The input value is an 8-bit integer, extend it to a 16-bit integer. + // as_int16 = zext(f8_value, int16) + llvm::Value* as_int16 = b->CreateZExt(f8_value, b->getInt16Ty()); + + // Get the sign and shift it to F16 position. + // f4_sign = as_int16 >> 3 + // f16_sign_bit = f4_sign << 15 + llvm::Value* f4_sign = b->CreateLShr(as_int16, 3); + llvm::Value* f16_sign_bit = b->CreateShl(f4_sign, 15); + + // Get exponent and mantissa bits without the sign. + // f4_bits = as_int16 & 0x7 + // f16_bits = f4_bits << (f16_mantissa - f4_mantissa) + llvm::Value* f4_bits = b->CreateAnd(as_int16, i16_const(0x7)); + llvm::Value* f16_bits = b->CreateShl(f4_bits, mantissa_diff); + + // Convert F4 exponent to F16 exponent by readjusting the exponent bias. + // f16_normal = f16_bits + ((f16_bias - f4_bias) << f16_mantissa) + constexpr int f16_exponent_offset = bias_diff << 10; + llvm::Value* f16_normal = + b->CreateAdd(f16_bits, i16_const(f16_exponent_offset)); + + // For denormal and zero, the exponent is different. Handle these cases + // separately below. + // is_denorm_or_zero = f4_bits <= 1 + // is_zero = f4_bits == 0 + llvm::Value* is_denorm_or_zero = b->CreateICmpULE(f4_bits, i16_const(1)); + llvm::Value* is_zero = b->CreateICmpEQ(f4_bits, i16_const(0)); + + // Chain of selects that handles the special cases. + // f16_result = is_denorm_or_zero ? (is_zero ? 
0.0 : 0.5) : f16_normal + llvm::Value* f16_result = b->CreateSelect( + is_denorm_or_zero, + b->CreateSelect(is_zero, i16_const(0x0000), i16_const(0x3800)), + f16_normal); // Add sign to the resulting value. - llvm::Value* f16_signed = b->CreateOr(f16_result, sign_shifted); - return b->CreateBitCast(f16_signed, b->getHalfTy()); + // f16_signed_result = f16_sign_bit | f16_result + llvm::Value* f16_signed_result = b->CreateOr(f16_result, f16_sign_bit); + return b->CreateBitCast(f16_signed_result, b->getHalfTy()); } llvm::Value* EmitF32ToF8e8m0fnu(llvm::Value* f32_value, llvm::IRBuilder<>* b) { - llvm::Value* as_int32 = b->CreateBitCast(f32_value, b->getInt32Ty()); - - // Result is NaN if input is zero, negative, infinity or NaN. auto i32_const = [&](int val) { return llvm::ConstantInt::get(b->getInt32Ty(), val); }; - llvm::Value* is_denorm = b->CreateICmpULE(as_int32, i32_const(0x400000)); - llvm::Value* is_nan = - b->CreateOr(b->CreateICmpSLE(as_int32, i32_const(0)), - b->CreateICmpSGE(as_int32, i32_const(0x7F400000))); - // Round the value and extract exponent. - llvm::Value* rounded = b->CreateAdd(as_int32, i32_const(0x400000)); - llvm::Value* shifted = b->CreateAShr(rounded, 23); - llvm::Value* finite = b->CreateSelect(is_denorm, i32_const(0x00), shifted); - llvm::Value* f32_result = b->CreateSelect(is_nan, i32_const(0xFF), finite); + // Cast the input value to an integer for bitwise manipulation. + // as_int32 = bitcast(f32_value, int) + llvm::Value* as_int32 = b->CreateBitCast(f32_value, b->getInt32Ty()); + + // Check if the input is zero, negative, overflow, infinity or NaN. + // All of these cases cannot be represented in the E8M0 format. 
+ // is_zero_or_negative = as_int32 <= 0 + // is_overflow_or_nan = as_int32 >= 0x1.8p127 + // is_nan = is_zero_or_negative | is_overflow_or_nan + llvm::Value* is_zero_or_negative = b->CreateICmpSLE(as_int32, i32_const(0)); + llvm::Value* is_overflow_or_nan = + b->CreateICmpSGE(as_int32, i32_const(0x7F400000)); // 1.5 * 2^127 + llvm::Value* is_nan = b->CreateOr(is_zero_or_negative, is_overflow_or_nan); + + // Check if the input is a denormal which should round to the minimum value + // (2^-127), as there is no zero value. + // is_denorm = as_int32 <= 0x1p-127 + llvm::Value* is_denorm = + b->CreateICmpULE(as_int32, i32_const(0x400000)); // 1.0 * 2^-127 + + // Round the value to nearest (ties always round up) and discard the mantissa. + // rounded = as_int32 + 0x1p-127 + // f8_normal = rounded >> f32_mantissa + llvm::Value* rounded = + b->CreateAdd(as_int32, i32_const(0x400000)); // 1.0 * 2^-127 + llvm::Value* f8_normal = b->CreateAShr(rounded, 23); + + // Chain of selects that handles the special cases. + // f8_result = is_nan ? 0xFF : (is_denorm ? 0x00 : f8_normal) + llvm::Value* f8_result = + b->CreateSelect(is_nan, i32_const(0xFF), + b->CreateSelect(is_denorm, i32_const(0x00), f8_normal)); // Truncate to the result type. - return b->CreateTrunc(f32_result, b->getInt8Ty()); + return b->CreateTrunc(f8_result, b->getInt8Ty()); } llvm::Value* EmitF8e8m0fnuToF32(llvm::Value* f8_value, llvm::IRBuilder<>* b) { - // Shift exponent to the left for the normal case. - llvm::Value* as_int32 = b->CreateZExt(f8_value, b->getInt32Ty()); - llvm::Value* shifted = b->CreateShl(as_int32, 23); - - // Special values for 0x00 (denorm) and 0xFF (NaN). auto i32_const = [&](int val) { return llvm::ConstantInt::get(b->getInt32Ty(), val); }; + + // The input value is an 8-bit integer, extend it to a 32-bit integer. + // as_int32 = zext(f8_value, int32) + llvm::Value* as_int32 = b->CreateZExt(f8_value, b->getInt32Ty()); + + // Check if the input is a denormal or NaN. 
+ // is_zero = as_int32 == 0x00 + // is_nan = as_int32 == 0xFF llvm::Value* is_zero = b->CreateICmpEQ(as_int32, i32_const(0)); llvm::Value* is_nan = b->CreateICmpEQ(as_int32, i32_const(0xFF)); - llvm::Value* denorm_or_shifted = - b->CreateSelect(is_zero, i32_const(0x00400000), shifted); - llvm::Value* f32_result = - b->CreateSelect(is_nan, i32_const(0x7FC00000), denorm_or_shifted); + // Shift exponent to the left for the normal case. + // f32_normal = as_int32 << mantissa_diff + llvm::Value* f32_normal = b->CreateShl(as_int32, 23); + + // Chain of selects that handles the special cases. + // f32_result = is_nan ? 0x7FC00000 : (is_zero ? 0x1p-127 : f32_normal) + llvm::Value* f32_result = b->CreateSelect( + is_nan, i32_const(0x7FC00000), + b->CreateSelect(is_zero, i32_const(0x400000), f32_normal)); + + // Bitcast integer bits to the result type. return b->CreateBitCast(f32_result, b->getFloatTy()); } diff --git a/xla/service/gpu/fusions/transforms/lower_tensors.cc b/xla/service/gpu/fusions/transforms/lower_tensors.cc index 00a980f28e81c..61cb741b40137 100644 --- a/xla/service/gpu/fusions/transforms/lower_tensors.cc +++ b/xla/service/gpu/fusions/transforms/lower_tensors.cc @@ -495,21 +495,18 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value, // Needed to support complex element type. 
mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_element_type = converter.convertType(element_type); - if (mlir::isa(element_type)) { - int bit_width = mlir::cast(element_type).getWidth(); - if (bit_width == 4) { - num_elements = CeilOfRatio(num_elements, 2); - llvm_element_type = b.getI8Type(); - auto unpacked_data = - mlir::cast(value).getRawData(); - std::vector packed_data(num_elements); - absl::Span packed_data_span = - absl::MakeSpan(packed_data.data(), packed_data.size()); - PackIntN(4, unpacked_data, packed_data_span); - value = mlir::DenseElementsAttr::getFromRawBuffer( - mlir::RankedTensorType::get({num_elements}, llvm_element_type), - packed_data); - } + if (element_type.getIntOrFloatBitWidth() == 4) { + num_elements = CeilOfRatio(num_elements, 2); + llvm_element_type = b.getI8Type(); + auto unpacked_data = + mlir::cast(value).getRawData(); + std::vector packed_data(num_elements); + absl::Span packed_data_span = + absl::MakeSpan(packed_data.data(), packed_data.size()); + PackIntN(4, unpacked_data, packed_data_span); + value = mlir::DenseElementsAttr::getFromRawBuffer( + mlir::RankedTensorType::get({num_elements}, llvm_element_type), + packed_data); } auto array_ty = mlir::LLVM::LLVMArrayType::get(llvm_element_type, num_elements); diff --git a/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir b/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir index 69377549340b7..3074483a77b36 100644 --- a/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir +++ b/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir @@ -743,4 +743,44 @@ func.func @complex_expm1_approx(%arg0: tensor<3xcomplex>, %i: index) } // CHECK-LABEL: @complex_expm1_approx // CHECK: math.expm1 -// CHECK-COUNT-6: math.fma \ No newline at end of file +// CHECK-COUNT-6: math.fma + +// ----- + +func.func @f4_constant(%arg0: tensor<3xf4E2M1FN>, %arg1: index) -> f4E2M1FN { + %cst = arith.constant dense<[0.5, -0.5, 2.5]> : tensor<3xf4E2M1FN> + %extracted = 
tensor.extract %arg0[%arg1] : tensor<3xf4E2M1FN> + %extracted_0 = tensor.extract %cst[%arg1] : tensor<3xf4E2M1FN> + %0 = arith.addf %extracted, %extracted_0 : f4E2M1FN + return %0 : f4E2M1FN +} +// CHECK: llvm.mlir.global private constant +// CHECK-SAME: dense<[25, 64]> +// CHECK-LABEL: @f4_constant + +// ----- + +func.func @transfer_read_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}) -> vector<2xf4E2M1FN> { + %c16 = arith.constant 16 : index + %c0 = arith.constant 0.0 : f4E2M1FN + %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf4E2M1FN>, vector<2xf4E2M1FN> + func.return %out : vector<2xf4E2M1FN> +} +// CHECK-LABEL: @transfer_read_f4 +// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[8] +// CHECK: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xi4> +// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xi4> to vector<2xf4E2M1FN> +// CHECK: return %[[OUT]] : vector<2xf4E2M1FN> + +// ----- + +func.func @transfer_write_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}, + %arg1: vector<2xf4E2M1FN>) -> tensor<43xf4E2M1FN> { + %c10 = arith.constant 10 : index + %out = vector.transfer_write %arg1, %arg0[%c10] : vector<2xf4E2M1FN>, tensor<43xf4E2M1FN> + func.return %out : tensor<43xf4E2M1FN> +} +// CHECK-LABEL: @transfer_write_f4 +// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %arg0[5] : (!llvm.ptr) -> !llvm.ptr, i8 +// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xf4E2M1FN> to vector<2xi4> +// CHECK: llvm.store %[[OUT]], %[[PTR]] : vector<2xi4>, !llvm.ptr diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index 8d7e32b05e3c5..59766747e2ca7 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -1957,8 +1957,7 @@ XLA_TEST_F(ConvertTest, ConvertF16F4e2m1fnRoundtrip) { // Denormal tests {0x1.0p-1, 0x1.0p-1}, // Denormal without rounding - {0x1.4p-1, 0x1.0p-1}, // Round-to-even down - {0x1.Cp-1, 0x1.0p0}, // Round-to-even up + {0x1.8p-1, 0x1.0p0}, // 
Round-to-even up {0x1.6p-1, 0x1.0p-1}, // Round-to-nearest down {0x1.Ep-1, 0x1.0p0}, // Round-to-nearest up {0x1p-2, 0}, // Largest number that underflows @@ -2007,8 +2006,7 @@ XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F4e2m1fnRoundtrip)) { // Denormal tests {0x1.0p-1, 0x1.0p-1}, // Denormal without rounding - {0x1.4p-1, 0x1.0p-1}, // Round-to-even down - {0x1.Cp-1, 0x1.0p0}, // Round-to-even up + {0x1.8p-1, 0x1.0p0}, // Round-to-even up {0x1.6p-1, 0x1.0p-1}, // Round-to-nearest down {0x1.Ep-1, 0x1.0p0}, // Round-to-nearest up {0x1p-2, 0}, // Largest number that underflows @@ -2109,7 +2107,7 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e8m0fnuRoundtrip) { {inf, nan}, // clang-format on {0x1.8p1, 0x1p2}, // Round-to-even up - {0x1.8p2, 0x1p3}, // Round-to-even up (not a mistake) + {0x1.8p2, 0x1p3}, // Round-to-even up (always rounds up) {0x1p127, 0x1p127}, // Max value {0x1.7FFFFEp127, 0x1p127}, // Largest number that doesn't overflow {0x1.8p127, nan}, // Smallest number that overflows @@ -2147,11 +2145,16 @@ XLA_TYPED_TEST(ConvertTestT, ConvertF8e8m0fnuRoundtripExhaustive) { this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -// This test is disabled on CPU, as converting 0x1p-127 from double to float -// using CVTSD2SS on x64 results in an underflow (even though the result is -// representable as denormalized float32). -XLA_TYPED_TEST(ConvertTestT, - DISABLED_ON_CPU(ConvertF8e8m0fnuRoundtripExhaustive2)) { +XLA_TYPED_TEST(ConvertTestT, ConvertF8e8m0fnuRoundtripExhaustive2) { +#ifdef XLA_TEST_BACKEND_CPU + // This test is disabled on CPU, as converting 0x1p-127 from double to float + // using CVTSD2SS on x64 results in an underflow (even though the result is + // representable as denormalized float32). + if (std::is_same_v) { + GTEST_SKIP() << "Skipping test for double precision floating point that " + "loses denormal value during conversion"; + } +#endif // Convert from supported floating point type to FP8. 
XlaBuilder builder(this->TestName()); diff --git a/xla/util.h b/xla/util.h index 6e3d60a7d0dea..a5fea9e787b7f 100644 --- a/xla/util.h +++ b/xla/util.h @@ -658,6 +658,8 @@ template auto SignAndMagnitude(T x) { using BitType = UnsignedIntegerTypeForSizeType; BitType x_abs_bits = Eigen::numext::bit_cast(Eigen::numext::abs(x)); + // Eigen implements the sign value to be either all-zeros (for positive input) + // or all-ones (for negative input). BitType x_sign = Eigen::numext::bit_cast(Eigen::numext::signbit(x)); if constexpr (!has_negative_zero_v) { // f8e4m3b11, f8e4m3fnuz, and f8e5m2fnuz don't support -0, adjust negative