Skip to content

Commit

Permalink
Add F8E8M0FNU type
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-kozub committed Dec 3, 2024
1 parent 87da2eb commit e0ee48c
Show file tree
Hide file tree
Showing 71 changed files with 809 additions and 142 deletions.
99 changes: 99 additions & 0 deletions third_party/tsl/third_party/py/ml_dtypes/e8m0.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
diff --git a/include/float8.h b/include/float8.h
index 51d91fd..c53eb45 100644
--- a/include/float8.h
+++ b/include/float8.h
@@ -1021,11 +1021,11 @@ struct numeric_limits_float8_e8m0fnu : public numeric_limits_float8_base {
static inline constexpr const int max_digits10 =
MaxDigits10FromDigits(digits);
// 2**-127 smallest valid normalized value..
- static inline constexpr const int min_exponent = -127 + 1;
+ static inline constexpr const int min_exponent = -kExponentBias + 1;
static inline constexpr const int min_exponent10 =
MinExponent10FromMinExponent(min_exponent);
// 128 encoding using for NaN
- static inline constexpr const int max_exponent = 127;
+ static inline constexpr const int max_exponent = kExponentBias + 1;
static inline constexpr const int max_exponent10 =
MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
static inline constexpr const bool is_iec559 = false;
@@ -1292,7 +1292,8 @@ struct Traits<float8_e8m0fnu> : public TraitsBase<float8_e8m0fnu> {
};

template <typename Bits>
-constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff) {
+constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff,
+ bool use_implicit_bit) {
// Round to nearest even by adding a bias term.
// Consider a bit pattern
// FFF...FLRTT...T,
@@ -1301,9 +1302,12 @@ constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff) {
// - L is 1, R is 1, OR
// - L is 0, R is 1, any T is one.
// We do this by adding L to a bit pattern consisting of all T = 1.
- Bits bias = roundoff == 0
- ? 0
- : ((bits >> roundoff) & 1) + (Bits{1} << (roundoff - 1)) - 1;
+ //
+ // When rounding to zero mantissa (E8M0 type), the L bit is implicitly 1 (do
+ // not use the exponent bits for rounding). Add only the R bit in this case.
+ Bits bias = !use_implicit_bit
+ ? ((bits >> roundoff) & 1) + (Bits{1} << (roundoff - 1)) - 1
+ : Bits{1} << (roundoff - 1);
return bits + bias;
}

@@ -1443,6 +1447,7 @@ struct ConvertImpl<From, To, kSaturate, kTruncate,
}

const int biased_from_exponent = from_bits >> kFromMantissaBits;
+ const bool to_zero_mantissa = kToMantissaBits == 0;

// `To` supports more exponents near zero which means that some subnormal
// values in `From` may become normal.
@@ -1473,11 +1478,14 @@ struct ConvertImpl<From, To, kSaturate, kTruncate,
}

// Truncate/round mantissa if necessary.
- if constexpr (kDigitShift > 0) {
+ if constexpr (kDigitShift >= 0) {
bits <<= kDigitShift;
} else {
if constexpr (!kTruncate) {
- bits = RoundBitsToNearestEven(bits, -kDigitShift);
+ // When converting float to e8m0, the bits represent a denormal,
+ // so don't use the implicit mantissa bit for rounding.
+ bits = RoundBitsToNearestEven(
+ bits, -kDigitShift, to_zero_mantissa && kExponentOffset != 0);
}
bits >>= -kDigitShift;
}
@@ -1514,8 +1522,8 @@ struct ConvertImpl<From, To, kSaturate, kTruncate,
// otherwise the lower precision bits may already be lost. There
// is an edge-case where rounding to a normalized value would
// normally round down, but for a subnormal, we need to round up.
- rounded_from_bits =
- RoundBitsToNearestEven(rounded_from_bits, exponent_shift);
+ rounded_from_bits = RoundBitsToNearestEven(rounded_from_bits,
+ exponent_shift, false);
}
bits = rounded_from_bits >> exponent_shift;
}
@@ -1532,7 +1540,8 @@ struct ConvertImpl<From, To, kSaturate, kTruncate,
WideBits rounded_from_bits = from_bits;
if constexpr (kDigitShift < 0) {
if constexpr (!kTruncate) {
- rounded_from_bits = RoundBitsToNearestEven(from_bits, -kDigitShift);
+ rounded_from_bits =
+ RoundBitsToNearestEven(from_bits, -kDigitShift, to_zero_mantissa);
}
// Zero-out tail bits.
rounded_from_bits &= ~((WideBits{1} << (-kDigitShift)) - 1);
@@ -1602,7 +1611,7 @@ struct ConvertImpl<Eigen::half, float8_e5m2, kSaturate, kTruncate> {
}

if constexpr (!kTruncate) {
- from_bits = RoundBitsToNearestEven(from_bits, 8);
+ from_bits = RoundBitsToNearestEven(from_bits, 8, false);
// Rounding can cause an overflow to infinity. Clamp to the largest finite
// value if saturation is requested.
if constexpr (kSaturate) {
1 change: 1 addition & 0 deletions third_party/tsl/third_party/py/ml_dtypes/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def repo():
tf_http_archive(
name = "ml_dtypes",
build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD",
patch_file = ["//third_party/py/ml_dtypes:e8m0.patch"],
link_files = {
"//third_party/py/ml_dtypes:ml_dtypes.tests.BUILD": "tests/BUILD.bazel",
"//third_party/py/ml_dtypes:LICENSE": "LICENSE",
Expand Down
1 change: 1 addition & 0 deletions third_party/tsl/tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
using float8_e5m2 = ::ml_dtypes::float8_e5m2;
using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz;
using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu;

using int2 = ::ml_dtypes::int2;
using uint2 = ::ml_dtypes::uint2;
Expand Down
14 changes: 14 additions & 0 deletions xla/array2d_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,20 @@ TEST(Array2dTest, LinspaceF4E2M1FN) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}

// Builds a 3x2 array of tsl::float8_e8m0fnu via MakeLinspaceArray2D(1.0, 3.5)
// and verifies its dimensions and the float value read back from each cell.
TEST(Array2dTest, LinspaceF8E8M0FNU) {
  auto grid = MakeLinspaceArray2D<tsl::float8_e8m0fnu>(1.0, 3.5, 3, 2);

  // Reads the element at (row, col) back as a float.
  const auto as_float = [&grid](int row, int col) {
    return static_cast<float>((*grid)(row, col));
  };

  // Dimensions match the requested 3x2 shape.
  EXPECT_EQ(grid->n1(), 3);
  EXPECT_EQ(grid->n2(), 2);

  // Row 0.
  EXPECT_FLOAT_EQ(as_float(0, 0), 1.0);
  EXPECT_FLOAT_EQ(as_float(0, 1), 2.0);  // 1.5 rounded up
  // Row 1.
  EXPECT_FLOAT_EQ(as_float(1, 0), 2.0);
  EXPECT_FLOAT_EQ(as_float(1, 1), 2.0);  // 2.5 rounded down
  // Row 2.
  EXPECT_FLOAT_EQ(as_float(2, 0), 4.0);  // 3.0 rounded up
  EXPECT_FLOAT_EQ(as_float(2, 1), 4.0);  // 3.5 rounded up
}

TEST(Array2dTest, Stringification) {
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
const std::string expected = R"([[1, 1.5],
Expand Down
9 changes: 7 additions & 2 deletions xla/comparison_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,13 @@ class Comparison {
// -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN
// Reference:
// https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
if constexpr (std::numeric_limits<T>::is_signed) {
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
} else {
using R = UnsignedIntegerTypeForSizeType<sizeof(T)>;
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
}
}
}
// Applies the comparison from this Comparison's direction and ordering.
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ inline std::ostream& operator<<(std::ostream& os,
return os << "F8E5M2FNUZ";
case XLA_FFI_DataType_F8E4M3FNUZ:
return os << "F8E4M3FNUZ";
case XLA_FFI_DataType_F8E8M0FNU:
return os << "F8E8M0FNU";
}
}

Expand Down
1 change: 1 addition & 0 deletions xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ typedef enum {
XLA_FFI_DataType_F8E5M2FNUZ = 24,
XLA_FFI_DataType_F8E4M3FNUZ = 25,
XLA_FFI_DataType_F4E2M1FN = 30,
XLA_FFI_DataType_F8E8M0FNU = 31,
} XLA_FFI_DataType;
// LINT.ThenChange(ffi_test.cc)

Expand Down
3 changes: 3 additions & 0 deletions xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ enum class DataType : uint8_t {
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
F8E3M4 = XLA_FFI_DataType_F8E3M4,
F4E2M1FN = XLA_FFI_DataType_F4E2M1FN,
F8E8M0FNU = XLA_FFI_DataType_F8E8M0FNU,
};

// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
Expand Down Expand Up @@ -108,6 +109,7 @@ inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
inline constexpr DataType F8E3M4 = DataType::F8E3M4;
inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN;
inline constexpr DataType F8E8M0FNU = DataType::F8E8M0FNU;

inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
return os << static_cast<XLA_FFI_DataType>(dtype);
Expand All @@ -130,6 +132,7 @@ constexpr size_t ByteWidth(DataType dtype) {
case DataType::F8E4M3FNUZ:
case DataType::F8E3M4:
case DataType::F4E2M1FN:
case DataType::F8E8M0FNU:
return 1;
case DataType::S16:
case DataType::U16:
Expand Down
3 changes: 3 additions & 0 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ TEST(FfiTest, DataTypeEnumValue) {
EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4));
EXPECT_EQ(encoded(PrimitiveType::F8E8M0FNU), encoded(DataType::F8E8M0FNU));
}

TEST(FfiTest, DataTypeByteWidth) {
Expand Down Expand Up @@ -196,6 +197,8 @@ TEST(FfiTest, DataTypeByteWidth) {
ByteWidth(DataType::F8E4M3FNUZ));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4),
ByteWidth(DataType::F8E3M4));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E8M0FNU),
ByteWidth(DataType::F8E8M0FNU));
}

TEST(FfiTest, ErrorEnumValue) {
Expand Down
1 change: 1 addition & 0 deletions xla/ffi/call_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
case PrimitiveType::F8E5M2FNUZ:
case PrimitiveType::F8E4M3FNUZ:
case PrimitiveType::F8E3M4:
case PrimitiveType::F8E8M0FNU:
return static_cast<XLA_FFI_DataType>(primitive_type);
default:
DCHECK(false) << "Unsupported primitive type "
Expand Down
17 changes: 17 additions & 0 deletions xla/fp_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,23 @@ TEST(FPDistanceTest, F4E2M1FNDistance) {
4);
}

// Verifies CalculateDistanceInFloats for tsl::float8_e8m0fnu operands that are
// identical, one step apart, and two steps apart.
TEST(FPDistanceTest, F8E8M0FNUDistance) {
  using F8 = tsl::float8_e8m0fnu;
  const F8 half(0.5);
  const F8 one(1.0);
  const F8 two(2.0);

  // a & b are equal
  EXPECT_EQ(CalculateDistanceInFloats<F8>(one, one), 0);

  // one step apart
  EXPECT_EQ(CalculateDistanceInFloats<F8>(one, two), 1);

  // two steps apart
  EXPECT_EQ(CalculateDistanceInFloats<F8>(half, two), 2);
}

TEST(FPDistanceTest, F8E3M4Distance) {
// a & b are equal
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/evaluator/hlo_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3722,7 +3722,7 @@ absl::StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
const Shape& result_shape) {
std::function<ResultT(Fp, Uint)> stochastic_convert_op =
[](Fp operand, Uint random) -> ResultT {
bool is_negative = static_cast<bool>(Eigen::numext::signbit(operand));
bool is_negative = static_cast<bool>(SignAndMagnitude(operand).first);
if (Eigen::numext::isinf(operand)) {
return is_negative ? std::numeric_limits<ResultT>::min()
: std::numeric_limits<ResultT>::max();
Expand Down
1 change: 1 addition & 0 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1742,6 +1742,7 @@ extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e3m4, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e8m0fnu, float>;

} // namespace xla

Expand Down
1 change: 1 addition & 0 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ limitations under the License.

namespace xla {
template class HloEvaluatorTypedVisitor<tsl::float4_e2m1fn, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e8m0fnu, float>;
} // namespace xla
59 changes: 33 additions & 26 deletions xla/hlo/transforms/expanders/comparison_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,34 +115,41 @@ absl::StatusOr<HloInstruction*> ComparisonExpander::ExpandInstruction(
ShapeUtil::ChangeElementType(rhs->shape(), compare_type), rhs));
}

int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type());
PrimitiveType signed_type =
primitive_util::SignedIntegralTypeForBitWidth(bit_width);
auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type);

auto zero_value = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type)));
zero_value = computation->AddInstruction(
HloInstruction::CreateBroadcast(signed_shape, zero_value, {}));

auto min_value = computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::MinValue(signed_shape.element_type())));
min_value = computation->AddInstruction(
HloInstruction::CreateBroadcast(signed_shape, min_value, {}));

auto max_value = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type)));
max_value = computation->AddInstruction(
HloInstruction::CreateBroadcast(signed_shape, max_value, {}));

lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value,
min_value, max_value);
rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value,
min_value, max_value);
if (compare_type != F8E8M0FNU) {
int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type());
PrimitiveType signed_type =
primitive_util::SignedIntegralTypeForBitWidth(bit_width);
auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type);

auto zero_value = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type)));
zero_value = computation->AddInstruction(
HloInstruction::CreateBroadcast(signed_shape, zero_value, {}));

auto min_value = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::MinValue(signed_type)));
min_value = computation->AddInstruction(
HloInstruction::CreateBroadcast(signed_shape, min_value, {}));

auto max_value = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type)));
max_value = computation->AddInstruction(
HloInstruction::CreateBroadcast(signed_shape, max_value, {}));

lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value,
min_value, max_value);
rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value,
min_value, max_value);
} else {
auto int8_shape = ShapeUtil::ChangeElementType(lhs->shape(), U8);
lhs = computation->AddInstruction(
HloInstruction::CreateBitcastConvert(int8_shape, lhs));
rhs = computation->AddInstruction(
HloInstruction::CreateBitcastConvert(int8_shape, rhs));
}

auto new_compare = computation->AddInstruction(HloInstruction::CreateCompare(
instruction->shape(), lhs, rhs, compare->direction(),
Comparison::Type::kSigned));
instruction->shape(), lhs, rhs, compare->direction()));

VLOG(2) << "New comparison instruction for total order:"
<< new_compare->ToString();
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/transforms/simplifiers/float_normalization_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class FloatNormalizationF8Test
INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test,
::testing::Values(F4E2M1FN, F8E3M4, F8E4M3,
F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ,
F8E5M2, F8E5M2FNUZ));
F8E5M2, F8E5M2FNUZ, F8E8M0FNU));

TEST_F(FloatNormalizationTest, NoopIfSupported) {
auto builder = HloComputation::Builder(TestName());
Expand Down
11 changes: 10 additions & 1 deletion xla/hlo/translate/hlo_to_mhlo/tests/import.hlo
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,9 @@ add {

// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN>
%constant.14 = f4e2m1fn[4] constant({1, 2, 3, 4})

// CHECK: %[[VAL_15:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00]> : tensor<4xf8E8M0FNU>
%constant.15 = f8e8m0fnu[4] constant({1, 2, 4, 8})
}

// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
Expand Down Expand Up @@ -551,7 +554,13 @@ add {
%convert.17 = f4e2m1fn[4] convert(f32[4] %convert.16)

// CHECK-NEXT: %15 = mhlo.convert %14 : (tensor<4xf4E2M1FN>) -> tensor<4xf32>
ROOT %convert.18 = f32[4] convert(f4e2m1fn[4] %convert.17)
%convert.18 = f32[4] convert(f4e2m1fn[4] %convert.17)

// CHECK-NEXT: %16 = mhlo.convert %15 : (tensor<4xf32>) -> tensor<4xf8E8M0FNU>
%convert.19 = f8e8m0fnu[4] convert(f32[4] %convert.18)

// CHECK-NEXT: %17 = mhlo.convert %16 : (tensor<4xf8E8M0FNU>) -> tensor<4xf32>
ROOT %convert.20 = f32[4] convert(f8e8m0fnu[4] %convert.19)
}

// CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8>
Expand Down
11 changes: 9 additions & 2 deletions xla/hlo/translate/mhlo_to_hlo/tests/export.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,9 @@ func.func @main() {
// CHECK: f4e2m1fn[4] constant({1, 2, 3, 4})
%cst_18 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN>

// CHECK: f8e8m0fnu[4] constant({1, 2, 4, 8})
%cst_19 = arith.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00]> : tensor<4xf8E8M0FNU>

func.return
}

Expand Down Expand Up @@ -744,7 +747,9 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32>
%12 = "mhlo.convert"(%11) : (tensor<2xf32>) -> tensor<2xf4E2M1FN>
%13 = "mhlo.convert"(%12) : (tensor<2xf4E2M1FN>) -> tensor<2xf32>
func.return %13 : tensor<2xf32>
%14 = "mhlo.convert"(%13) : (tensor<2xf32>) -> tensor<2xf8E8M0FNU>
%15 = "mhlo.convert"(%14) : (tensor<2xf8E8M0FNU>) -> tensor<2xf32>
func.return %15 : tensor<2xf32>
}

// CHECK: ENTRY
Expand All @@ -762,7 +767,9 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]])
// CHECK: %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]])
// CHECK: %[[E2M1FN_VAL:.*]] = f4e2m1fn[2] convert(f32[2] %[[F32_VAL6]])
// CHECK: ROOT %[[F32_VAL7:.*]] = f32[2] convert(f4e2m1fn[2] %[[E2M1FN_VAL]])
// CHECK: %[[F32_VAL7:.*]] = f32[2] convert(f4e2m1fn[2] %[[E2M1FN_VAL]])
// CHECK: %[[E8M0FNU_VAL:.*]] = f8e8m0fnu[2] convert(f32[2] %[[F32_VAL7]])
// CHECK: ROOT %[[F32_VAL8:.*]] = f32[2] convert(f8e8m0fnu[2] %[[E8M0FNU_VAL]])

// -----

Expand Down
Loading

0 comments on commit e0ee48c

Please sign in to comment.