Skip to content

Commit

Permalink
Add F8E8M0FNU type
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-kozub committed Dec 18, 2024
1 parent d7d5af7 commit 9e8c7bc
Show file tree
Hide file tree
Showing 71 changed files with 810 additions and 140 deletions.
99 changes: 99 additions & 0 deletions third_party/tsl/third_party/py/ml_dtypes/e8m0.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
diff --git a/include/float8.h b/include/float8.h
index 51d91fd..c53eb45 100644
--- a/include/float8.h
+++ b/include/float8.h
@@ -1021,11 +1021,11 @@ struct numeric_limits_float8_e8m0fnu : public numeric_limits_float8_base {
static inline constexpr const int max_digits10 =
MaxDigits10FromDigits(digits);
// 2**-127 smallest valid normalized value..
- static inline constexpr const int min_exponent = -127 + 1;
+ static inline constexpr const int min_exponent = -kExponentBias + 1;
static inline constexpr const int min_exponent10 =
MinExponent10FromMinExponent(min_exponent);
// 128 encoding using for NaN
- static inline constexpr const int max_exponent = 127;
+ static inline constexpr const int max_exponent = kExponentBias + 1;
static inline constexpr const int max_exponent10 =
MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
static inline constexpr const bool is_iec559 = false;
@@ -1292,7 +1292,8 @@ struct Traits<float8_e8m0fnu> : public TraitsBase<float8_e8m0fnu> {
};

template <typename Bits>
-constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff) {
+constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff,
+ bool use_implicit_bit) {
// Round to nearest even by adding a bias term.
// Consider a bit pattern
// FFF...FLRTT...T,
@@ -1301,9 +1302,12 @@ constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff) {
// - L is 1, R is 1, OR
// - L is 0, R is 1, any T is one.
// We do this by adding L to a bit pattern consisting of all T = 1.
- Bits bias = roundoff == 0
- ? 0
- : ((bits >> roundoff) & 1) + (Bits{1} << (roundoff - 1)) - 1;
+ //
+ // When rounding to zero mantissa (E8M0 type), the L bit is implicitly 1 (do
+ // not use the exponent bits for rounding). Add only the R bit in this case.
+ Bits bias = !use_implicit_bit
+ ? ((bits >> roundoff) & 1) + (Bits{1} << (roundoff - 1)) - 1
+ : Bits{1} << (roundoff - 1);
return bits + bias;
}

@@ -1443,6 +1447,7 @@ struct ConvertImpl<From, To, kSaturate, kTruncate,
}

const int biased_from_exponent = from_bits >> kFromMantissaBits;
+ const bool to_zero_mantissa = kToMantissaBits == 0;

// `To` supports more exponents near zero which means that some subnormal
// values in `From` may become normal.
@@ -1473,11 +1478,14 @@ struct ConvertImpl<From, To, kSaturate, kTruncate,
}

// Truncate/round mantissa if necessary.
- if constexpr (kDigitShift > 0) {
+ if constexpr (kDigitShift >= 0) {
bits <<= kDigitShift;
} else {
if constexpr (!kTruncate) {
- bits = RoundBitsToNearestEven(bits, -kDigitShift);
+ // When converting float to e8m0, the bits represent a denormal,
+ // so don't use the implicit mantissa bit for rounding.
+ bits = RoundBitsToNearestEven(
+ bits, -kDigitShift, to_zero_mantissa && kExponentOffset != 0);
}
bits >>= -kDigitShift;
}
@@ -1514,8 +1522,8 @@ struct ConvertImpl<From, To, kSaturate, kTruncate,
// otherwise the lower precision bits may already be lost. There
// is an edge-case where rounding to a normalized value would
// normally round down, but for a subnormal, we need to round up.
- rounded_from_bits =
- RoundBitsToNearestEven(rounded_from_bits, exponent_shift);
+ rounded_from_bits = RoundBitsToNearestEven(rounded_from_bits,
+ exponent_shift, false);
}
bits = rounded_from_bits >> exponent_shift;
}
@@ -1532,7 +1540,8 @@ struct ConvertImpl<From, To, kSaturate, kTruncate,
WideBits rounded_from_bits = from_bits;
if constexpr (kDigitShift < 0) {
if constexpr (!kTruncate) {
- rounded_from_bits = RoundBitsToNearestEven(from_bits, -kDigitShift);
+ rounded_from_bits =
+ RoundBitsToNearestEven(from_bits, -kDigitShift, to_zero_mantissa);
}
// Zero-out tail bits.
rounded_from_bits &= ~((WideBits{1} << (-kDigitShift)) - 1);
@@ -1602,7 +1611,7 @@ struct ConvertImpl<Eigen::half, float8_e5m2, kSaturate, kTruncate> {
}

if constexpr (!kTruncate) {
- from_bits = RoundBitsToNearestEven(from_bits, 8);
+ from_bits = RoundBitsToNearestEven(from_bits, 8, false);
// Rounding can cause an overflow to infinity. Clamp to the largest finite
// value if saturation is requested.
if constexpr (kSaturate) {
1 change: 1 addition & 0 deletions third_party/tsl/third_party/py/ml_dtypes/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def repo():
tf_http_archive(
name = "ml_dtypes",
build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD",
patch_file = ["//third_party/py/ml_dtypes:e8m0.patch"],
link_files = {
"//third_party/py/ml_dtypes:ml_dtypes.tests.BUILD": "tests/BUILD.bazel",
"//third_party/py/ml_dtypes:LICENSE": "LICENSE",
Expand Down
1 change: 1 addition & 0 deletions third_party/tsl/tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
using float8_e5m2 = ::ml_dtypes::float8_e5m2;
using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz;
using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu;

using int1 = ::ml_dtypes::int1;
using uint1 = ::ml_dtypes::uint1;
Expand Down
14 changes: 14 additions & 0 deletions xla/array2d_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,20 @@ TEST(Array2dTest, LinspaceF4E2M1FN) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}

TEST(Array2dTest, LinspaceF8E8M0FNU) {
auto arr = MakeLinspaceArray2D<tsl::float8_e8m0fnu>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 2.0); // 1.5 rounded up
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 4.0); // 3.0 rounded up
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}

TEST(Array2dTest, Stringification) {
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
const std::string expected = R"([[1, 1.5],
Expand Down
82 changes: 59 additions & 23 deletions xla/backends/gpu/codegen/transforms/expand_float_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ int GetSignificandBits(mlir::FloatType ty) {
}

int GetExponentBias(mlir::FloatType ty) {
return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics());
return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()) -
ty.isFloat8E8M0FNU(); // No zero exponent for E8M0.
}

bool IsFNUZ(mlir::FloatType ty) {
Expand Down Expand Up @@ -215,6 +216,8 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) {
return (bits & 0b0111'1111) == 0b0111'1111;
} else if (ty.isFloat8E3M4()) {
return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'0000);
} else if (ty.isFloat8E8M0FNU()) {
return bits == 0xFF;
}
return bits == 0x80;
}
Expand Down Expand Up @@ -294,6 +297,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
} else {
wide_int_ty = b.getIntegerType(
std::max(from_int_ty.getWidth(), to_int_ty.getWidth()));
// Avoid overflow for bit shifts.
auto may_overflow = [&](mlir::Type a, mlir::Type b) {
return a.isFloat8E8M0FNU() && b.isF16();
};
if (may_overflow(from_ty, to_ty) || may_overflow(to_ty, from_ty)) {
wide_int_ty = b.getI32Type();
}
}
auto convert_int = [&](mlir::Type ty, Value v) -> Val {
if (v.getType() == ty) {
Expand All @@ -320,11 +330,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
};

// Shift bits to destination type, without sign bit.
Val from_sign_bit = from_bits.shrui(from_width - 1) != 0;
from_bits = from_bits & ((1ULL << (from_width - 1)) - 1);

Value result_is_inf = IsInf(value, b);
Value input_is_nan = IsNaN(value, b);
Val from_sign_bit;
if (!from_ty.isFloat8E8M0FNU()) {
from_sign_bit = from_bits.shrui(from_width - 1) != 0;
from_bits = from_bits & ((1ULL << (from_width - 1)) - 1);
}

auto cst_bits = [&](llvm::APFloat f) {
return cst(b.getIntegerType(llvm::APFloat::getSizeInBits(f.getSemantics())),
Expand All @@ -338,9 +348,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
if (to_ty.isFloat4E2M1FN()) {
to_inf = cst_bits(llvm::APFloat::getLargest(to_ty.getFloatSemantics()));
to_nan = to_zero | 0x8;
} else if (to_ty.isFloat8E8M0FNU()) {
to_inf = to_nan;
to_zero = Val{to_nan, &b};
}

auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) {
auto round_bits_to_nearest_even = [&](Val bits, Val roundoff,
bool use_implicit_bit = false) {
assert(bits.value.getType() == roundoff.value.getType());
// Round to nearest even by adding a bias term.
// Consider a bit pattern
Expand All @@ -350,9 +364,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
// - L is 1, R is 1, OR
// - L is 0, R is 1, any T is one.
// We do this by adding L to a bit pattern consisting of all T = 1.
Val rounded = (bits.shrui(roundoff) & 1) +
(bits.MakeConstant(1).shl(roundoff - 1) - 1);
Val bias{b.create<SelectOp>(roundoff == 0, roundoff, rounded), &b};
Val bias = !use_implicit_bit
? (bits.shrui(roundoff) & 1) +
(bits.MakeConstant(1).shl(roundoff - 1) - 1)
: bits.MakeConstant(1).shl(roundoff - 1);
return bits + bias;
};

Expand All @@ -362,9 +377,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
// Round the mantissa if it is shrinking.
Val rounded_from_bits = convert_int(wide_int_ty, from_bits);
if (digit_shift < 0) {
rounded_from_bits = round_bits_to_nearest_even(
from_bits, from_bits.MakeConstant(-digit_shift)) &
~((1ll << (-digit_shift)) - 1);
rounded_from_bits =
round_bits_to_nearest_even(
rounded_from_bits, rounded_from_bits.MakeConstant(-digit_shift),
/*use_implicit_bit=*/to_mantissa == 0) &
~((1ll << (-digit_shift)) - 1);
}

// Re-bias the exponent.
Expand Down Expand Up @@ -431,10 +448,12 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
Value biased_exp_sle_zero = biased_exponent.cmp(CmpIPredicate::sle, 0);
bits.value =
b.create<SelectOp>(biased_exp_sle_zero, subnormal_bits, normal_bits);
if (digit_shift > 0) {
if (digit_shift >= 0) {
bits = bits.shl(digit_shift);
} else {
bits = round_bits_to_nearest_even(bits, bits.MakeConstant(-digit_shift));
bits = round_bits_to_nearest_even(
bits, bits.MakeConstant(-digit_shift),
/*use_implicit_bit=*/to_mantissa == 0 && exp_offset != 0);
bits = bits.shrui(-digit_shift);
}
bits = convert_int(to_int_ty, bits);
Expand All @@ -443,11 +462,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
} else if (to_min_exp > from_min_exp) {
// `To` supports fewer exponents near zero which means that some values in
// `From` may become subnormal.
Val unbiased_exp = biased_from_exp - from_bias;
Val biased_to_exp = unbiased_exp + to_bias;
Val biased_to_exp = biased_from_exp + (to_bias - from_bias);
// Subnormals and zero.
// Round and shift mantissa down.
Val from_has_leading_one = biased_from_exp != 0;
Val from_has_leading_one =
!from_ty.isFloat8E8M0FNU() ? biased_from_exp != 0 : cst(i32_ty, 1);
Val from_has_leading_one_i32 = convert_int(i32_ty, from_has_leading_one);
from_has_leading_one = convert_int(from_int_ty, from_has_leading_one);
Val exponent_shift_i32 =
Expand Down Expand Up @@ -482,7 +501,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
result);
}

if (IsFNUZ(to_ty)) {
Value result_is_inf = IsInf(value, b);
Value input_is_nan = IsNaN(value, b);

if (to_ty.isFloat8E8M0FNU()) {
// Converting a negative number to E8M0 results in NaN.
input_is_nan = from_sign_bit | input_is_nan;
} else if (IsFNUZ(to_ty)) {
// Clear the sign bit if the result is zero (the output has no negative
// zero). Handle the edge case when the input is zero and the result is not.
Val result_is_non_zero =
Expand All @@ -494,14 +519,17 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
from_sign_bit = from_sign_bit ^ input_is_nan;
}

if (!from_ty.isFloat8E8M0FNU()) {
result = b.create<SelectOp>(from_bits == 0, to_zero, result);
}
result = b.create<SelectOp>(result_is_inf, to_inf, result);
result = b.create<SelectOp>(from_bits == 0, to_zero, result);
result = b.create<SelectOp>(input_is_nan, to_nan, result);

Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1));

// Insert sign bit.
result = b.create<SelectOp>(from_sign_bit, neg_result, result);
if (!from_ty.isFloat8E8M0FNU()) {
Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1));
result = b.create<SelectOp>(from_sign_bit, neg_result, result);
}
result = b.create<ma::BitcastOp>(to_ty, result);
return result;
}
Expand Down Expand Up @@ -598,6 +626,14 @@ struct RewriteAbsFPattern : public mlir::OpRewritePattern<mlir::math::AbsFOp> {
return rewriter.notifyMatchFailure(op,
"not an f8 (or less) or bf16 absf");
}

// If type is unsigned (E8M0), the operation is no-op.
if (!llvm::APFloat::semanticsHasSignedRepr(
src.getType().getFloatSemantics())) {
rewriter.replaceAllOpUsesWith(op, op.getOperand());
return mlir::success();
}

mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth());
Val value{b.create<ma::BitcastOp>(i_ty, src), &b};
Expand Down
13 changes: 13 additions & 0 deletions xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,16 @@ module {
// CHECK-LABEL: @f4_abs
// CHECK-NOT: math.absf
// CHECK: arith.constant 7 : i4

// -----

module {
func.func @e8m0_abs(%arg0: f8E8M0FNU) -> f8E8M0FNU {
%ret = math.absf %arg0 : f8E8M0FNU
return %ret : f8E8M0FNU
}
}

// CHECK-LABEL: @e8m0_abs
// CHECK-NOT: math.absf
// CHECK: return %arg0
9 changes: 7 additions & 2 deletions xla/comparison_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,13 @@ class Comparison {
// -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN
// Reference:
// https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
if constexpr (std::numeric_limits<T>::is_signed) {
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
} else {
using R = UnsignedIntegerTypeForSizeType<sizeof(T)>;
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
}
}
}
// Applies the comparison from this Comparison's direction and ordering.
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ inline std::ostream& operator<<(std::ostream& os,
return os << "F8E5M2FNUZ";
case XLA_FFI_DataType_F8E4M3FNUZ:
return os << "F8E4M3FNUZ";
case XLA_FFI_DataType_F8E8M0FNU:
return os << "F8E8M0FNU";
}
}

Expand Down
1 change: 1 addition & 0 deletions xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ typedef enum {
XLA_FFI_DataType_F8E5M2FNUZ = 24,
XLA_FFI_DataType_F8E4M3FNUZ = 25,
XLA_FFI_DataType_F4E2M1FN = 30,
XLA_FFI_DataType_F8E8M0FNU = 31,
} XLA_FFI_DataType;
// LINT.ThenChange(ffi_test.cc)

Expand Down
3 changes: 3 additions & 0 deletions xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ enum class DataType : uint8_t {
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
F8E3M4 = XLA_FFI_DataType_F8E3M4,
F4E2M1FN = XLA_FFI_DataType_F4E2M1FN,
F8E8M0FNU = XLA_FFI_DataType_F8E8M0FNU,
};

// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
Expand Down Expand Up @@ -108,6 +109,7 @@ inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
inline constexpr DataType F8E3M4 = DataType::F8E3M4;
inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN;
inline constexpr DataType F8E8M0FNU = DataType::F8E8M0FNU;

inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
return os << static_cast<XLA_FFI_DataType>(dtype);
Expand All @@ -130,6 +132,7 @@ constexpr size_t ByteWidth(DataType dtype) {
case DataType::F8E4M3FNUZ:
case DataType::F8E3M4:
case DataType::F4E2M1FN:
case DataType::F8E8M0FNU:
return 1;
case DataType::S16:
case DataType::U16:
Expand Down
Loading

0 comments on commit 9e8c7bc

Please sign in to comment.