From ca16839096feb93e0454ec380c5c707c30199346 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Wed, 6 Nov 2024 09:28:25 +0100 Subject: [PATCH] Add F4E2M1FN type: conversion codegen --- xla/service/elemental_ir_emitter.cc | 106 ++++++++++- xla/service/elemental_ir_emitter_test.cc | 13 +- xla/service/float8_fnuz_ir_emitter.cc | 14 +- .../fusions/transforms/expand_float_ops.cc | 105 +++++------ .../transforms/tests/expand_float_ops.mlir | 37 ++++ .../gpu/tests/float_conversions_test.cc | 4 +- xla/tests/convert_test.cc | 170 +++++++++++++++++- 7 files changed, 383 insertions(+), 66 deletions(-) diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index d1276e1717bab..5279057bfb4b0 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -809,6 +809,82 @@ llvm::Value* EmitF8e4m3b11fnuzToF16(llvm::Value* f8_value, return f16_value; } +absl::StatusOr EmitF16ToF4e2m1fn(llvm::Value* f16_value, + llvm::IRBuilder<>* b) { + TF_ASSIGN_OR_RETURN( + llvm::Value * reduced_precision, + EmitReducePrecisionIR( + /*src_ty=*/F16, f16_value, + /*dest_exponent_bits=*/primitive_util::ExponentWidth(F4E2M1FN) + 1, + /*dest_mantissa_bits=*/primitive_util::SignificandWidth(F4E2M1FN) - 1, + /*quiet_nans=*/false, b)); + llvm::Value* as_int16 = b->CreateBitCast(reduced_precision, b->getInt16Ty()); + llvm::Value* as_int8 = + b->CreateTrunc(b->CreateLShr(as_int16, 9), b->getInt8Ty()); + + // Extract sign, exponent and mantissa from reduced precision value. + auto i8_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt8Ty(), val); + }; + llvm::Value* f4_sign = b->CreateLShr(as_int8, 6); + llvm::Value* f4_bits = b->CreateAnd(as_int8, i8_const(0x3F)); + llvm::Value* f4_normal = b->CreateSub(f4_bits, i8_const(28)); + + // Special case for exponent overflow. + auto i16_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt16Ty(), val); + }; + llvm::Value* f16_bits = b->CreateAnd( + b->CreateBitCast(f16_value, b->getInt16Ty()), i16_const(0x7FFF)); + llvm::Value* is_overflow = + b->CreateICmpUGE(f16_bits, i16_const(0x4700)); // 7.0 + llvm::Value* is_nan = b->CreateICmpUGT(f16_bits, i16_const(0x7C00)); // inf + llvm::Value* max_or_nan = + b->CreateSelect(is_nan, i8_const(0x8), i8_const(0x7)); + llvm::Value* f4_normal_or_overflow = + b->CreateSelect(is_overflow, max_or_nan, f4_normal); + + // Special case for exponent underflow. + llvm::Value* is_underflow = b->CreateICmpSLE(f4_normal, i8_const(1)); + llvm::Value* is_one = b->CreateICmpUGE(f16_bits, i16_const(0x3A00)); // 0.75 + llvm::Value* is_zero = b->CreateICmpULE(f16_bits, i16_const(0x3400)); // 0.25 + llvm::Value* denorm_or_zero = + b->CreateSelect(is_zero, i8_const(0x0), i8_const(0x1)); + llvm::Value* f4_small = + b->CreateSelect(is_one, i8_const(0x2), denorm_or_zero); + llvm::Value* f4_result = + b->CreateSelect(is_underflow, f4_small, f4_normal_or_overflow); + + // Add sign to the resulting value. + return b->CreateOr(f4_result, b->CreateShl(f4_sign, 3)); +} + +llvm::Value* EmitF4e2m1fnToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { + llvm::Value* as_int16 = b->CreateZExt(f8_value, b->getInt16Ty()); + + // Extract sign, exponent and mantissa from reduced precision value. + auto i16_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt16Ty(), val); + }; + llvm::Value* sign = b->CreateLShr(as_int16, 3); + llvm::Value* sign_shifted = b->CreateShl(sign, 15); + llvm::Value* bits = b->CreateAnd(as_int16, i16_const(0x7)); + llvm::Value* bits_shifted = b->CreateShl(bits, 9); + + // Re-bias the exponent and handle denormals. + llvm::Value* f16_normal = b->CreateAdd(bits_shifted, i16_const(14 << 10)); + llvm::Value* is_denorm_or_zero = b->CreateICmpULE(bits, i16_const(1)); + llvm::Value* is_zero = b->CreateICmpEQ(bits, i16_const(0)); + llvm::Value* denorm_or_zero = + b->CreateSelect(is_zero, i16_const(0x0000), i16_const(0x3800)); + llvm::Value* f16_result = + b->CreateSelect(is_denorm_or_zero, denorm_or_zero, f16_normal); + + // Add sign to the resulting value. + llvm::Value* f16_signed = b->CreateOr(f16_result, sign_shifted); + return b->CreateBitCast(f16_signed, b->getHalfTy()); +} + llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, PrimitiveType from_type, PrimitiveType to_type, llvm::Module* module, @@ -902,6 +978,12 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), b_); } + if (to_type == F4E2M1FN) { + return EmitF16ToF4e2m1fn( + EmitIntegralToFloating(operand_value, from_type, F16, module_, + b_), + b_); + } if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz( F16, @@ -1105,6 +1187,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F4E2M1FN) { + TF_RET_CHECK(to_type != F4E2M1FN); + operand_value = EmitF4e2m1fnToF16(operand_value, b_); + from_type = F16; + if (from_type == to_type) { + return operand_value; + } + } if (from_type == F8E5M2FNUZ || from_type == F8E4M3FNUZ) { TF_RET_CHECK(to_type != from_type); PrimitiveType cast_type = @@ -1176,6 +1266,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitF16ToF8e4m3b11fnuz(operand_value, b_); } + if (to_type == F4E2M1FN) { + // Cast to F16 first. Casts to F4E2M1FN must be from F16. + if (from_type != F16) { + operand_value = b_->CreateFPCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + } + return EmitF16ToF4e2m1fn(operand_value, b_); + } if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz(from_type, operand_value, to_type, b_); } @@ -1721,6 +1819,9 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( } else if (operand_type == F8E4M3FN) { lhs_value = EmitF8e4m3fnToF16(lhs_value, b_); rhs_value = EmitF8e4m3fnToF16(rhs_value, b_); + } else if (operand_type == F4E2M1FN) { + lhs_value = EmitF4e2m1fnToF16(lhs_value, b_); + rhs_value = EmitF4e2m1fnToF16(rhs_value, b_); } else if (operand_type == F8E5M2FNUZ || operand_type == F8E4M3FNUZ) { TF_ASSIGN_OR_RETURN( lhs_value, @@ -3569,9 +3670,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( primitive_util::IsFloatingPointType(component_element_type)) << component_element_type; llvm::Type* float_ir_type; - if (component_element_type == F8E4M3FNUZ) { - float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_); - } else if (component_element_type == F8E5M2FNUZ) { + if (component_element_type == F8E4M3FNUZ || + component_element_type == F8E5M2FNUZ) { float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_); } else { float_ir_type = diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index c1edb9a4b856d..fe1b8f17efbb1 100644 --- a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -99,9 +99,10 @@ class ElementalIrEmitterExecutionTypedTest }; using FloatTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ElementalIrEmitterExecutionTypedTest, FloatTypes); @@ -614,7 +615,8 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { std::is_same() || std::is_same() || std::is_same() || - std::is_same()) { + std::is_same() || + std::is_same()) { GTEST_SKIP() << "Skipping test for type " << tname; } const auto hlo_text = absl::StrReplaceAll(R"( @@ -629,6 +631,9 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { TYPED_TEST(ElementalIrEmitterExecutionTypedTest, BatchDotFloat) { auto tname = this->TypeName(); + if (std::is_same()) { + GTEST_SKIP() << "Dot operation on E2M1 is not supported"; + } const auto hlo_text = absl::StrReplaceAll(R"( HloModule matmul diff --git a/xla/service/float8_fnuz_ir_emitter.cc b/xla/service/float8_fnuz_ir_emitter.cc index 4afb96362cf86..b0c012a75f9f0 100644 --- a/xla/service/float8_fnuz_ir_emitter.cc +++ b/xla/service/float8_fnuz_ir_emitter.cc @@ -40,6 +40,8 @@ namespace { absl::StatusOr PrimitiveTypeToAPFloatSemantics( PrimitiveType type) { switch (type) { + case F4E2M1FN: + return &llvm::APFloat::Float4E2M1FN(); case F8E3M4: return &llvm::APFloat::Float8E3M4(); case F8E4M3: @@ -72,6 +74,8 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilderBase* b, PrimitiveType type) { switch (type) { + case F4E2M1FN: + return b->getIntNTy(4); case F8E3M4: case F8E4M3: case F8E4M3B11FNUZ: @@ -649,8 +653,14 @@ absl::StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, llvm::ConstantInt::get(b->getInt8Ty(), 0x0u), sign); // Bitwise or the sign bit back in. - sign = b->CreateZExt(sign, output_int_type); - sign = b->CreateShl(sign, output_type_bit_width - BitWidth(input_type)); + int shift = output_type_bit_width - BitWidth(input_type); + if (shift >= 0) { + sign = b->CreateZExt(sign, output_int_type); + sign = b->CreateShl(sign, shift); + } else { + sign = b->CreateLShr(sign, -shift); + sign = b->CreateTrunc(sign, output_int_type); + } llvm::Value* result = b->CreateOr(sign, result_abs); // Bitcast to the output type. diff --git a/xla/service/gpu/fusions/transforms/expand_float_ops.cc b/xla/service/gpu/fusions/transforms/expand_float_ops.cc index 6fea3a97527f9..d4e69d2641775 100644 --- a/xla/service/gpu/fusions/transforms/expand_float_ops.cc +++ b/xla/service/gpu/fusions/transforms/expand_float_ops.cc @@ -166,6 +166,11 @@ int GetExponentBias(mlir::FloatType ty) { return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()); } +bool IsFNUZ(mlir::FloatType ty) { + return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() || + ty.isFloat8E5M2FNUZ(); +} + Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { auto ty = mlir::cast(value.getType()); if (mlir::LLVM::isCompatibleOuterType(ty)) { @@ -175,7 +180,7 @@ Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { return b.create(ma::CmpFPredicate::OEQ, value, inf); } - assert(ty.getIntOrFloatBitWidth() == 8); + assert(ty.getIntOrFloatBitWidth() <= 8); // F8E5M2, F8E4M3, F8E3M4 are the only 8 bit float with infinities. if (ty.isFloat8E5M2()) { Val bits{b.create(b.getI8Type(), value), &b}; @@ -196,6 +201,9 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { if (mlir::LLVM::isCompatibleOuterType(ty)) { return b.create(ma::CmpFPredicate::UNO, value, value); } + if (ty.isFloat4E2M1FN()) { + return b.create(false, b.getI1Type()); + } assert(ty.getIntOrFloatBitWidth() == 8); Val bits{b.create(b.getI8Type(), value), &b}; @@ -281,7 +289,7 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, auto to_int_ty = b.getIntegerType(to_ty.getIntOrFloatBitWidth()); mlir::IntegerType wide_int_ty; - if (from_ty.getWidth() == 8 && to_ty.getWidth() == 8) { + if (from_ty.getWidth() <= 8 && to_ty.getWidth() <= 8) { wide_int_ty = b.getI16Type(); } else { wide_int_ty = b.getIntegerType( @@ -300,21 +308,20 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, int64_t exp_offset = to_bias - from_bias; int digit_shift = to_mantissa - from_mantissa; - Val from_bits{ - b.create( - b.getIntegerType(value.getType().getIntOrFloatBitWidth()), value), - &b}; + int from_width = value.getType().getIntOrFloatBitWidth(); + Val from_bits{b.create(b.getIntegerType(from_width), value), + &b}; + if (from_width < 8) { + from_bits = convert_int(b.getIntegerType(8), from_bits); + } auto cst = [&](mlir::Type ty, int64_t n) -> Val { return {b.create(n, ty), &b}; }; // Shift bits to destination type, without sign bit. - Val from_sign_bit = - from_bits.shrui(value.getType().getIntOrFloatBitWidth() - 1) != 0; - - from_bits = - from_bits & ((1ULL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1); + Val from_sign_bit = from_bits.shrui(from_width - 1) != 0; + from_bits = from_bits & ((1ULL << (from_width - 1)) - 1); Value result_is_inf = IsInf(value, b); Value input_is_nan = IsNaN(value, b); @@ -327,6 +334,12 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, Value to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); Val to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); + // MX float types have neither infinities nor NaNs. + if (to_ty.isFloat4E2M1FN()) { + to_inf = cst_bits(llvm::APFloat::getLargest(to_ty.getFloatSemantics())); + to_nan = to_zero | 0x8; + } + auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) { assert(bits.value.getType() == roundoff.value.getType()); // Round to nearest even by adding a bias term. @@ -394,10 +407,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, Val bits = convert_int(wide_int_ty, from_bits); // Determine exponent in target type. - Value normalization_factor = - convert_int(i32_ty, - b.create(from_bits)) - - (from_int_ty.getWidth() - from_mantissa - 1); + Value clz = convert_int( + i32_ty, b.create(from_bits)); + Value msb = cst(i32_ty, std::max(from_width, 8) - 1) - clz; + Value normalization_factor = cst(i32_ty, from_mantissa) - msb; Val biased_exponent = cst(i32_ty, exp_offset + 1) - normalization_factor; // If the result is subnormal, adjust the subnormal bits to account for @@ -469,18 +482,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, result); } - // Handle types with no unsigned zero. - auto is_nuz = [](mlir::FloatType ty) { - return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() || - ty.isFloat8E5M2FNUZ(); - }; - - if (is_nuz(to_ty)) { + if (IsFNUZ(to_ty)) { // Clear the sign bit if the result is zero (the output has no negative - // zero). - Val result_is_non_zero = Val{result, &b} != 0; + // zero). Handle the edge case when the input is zero and the result is not. + Val result_is_non_zero = + (digit_shift > 0 ? from_bits : Val{result, &b}) != 0; from_sign_bit = from_sign_bit & result_is_non_zero; - } else if (is_nuz(from_ty)) { + } else if (IsFNUZ(from_ty)) { // Clear the sign bit if the input is NaN (it's positive but encoded as // negative 0). from_sign_bit = from_sign_bit ^ input_is_nan; @@ -506,8 +514,8 @@ struct RewriteTruncFPattern : public mlir::OpRewritePattern { using FloatValue = mlir::TypedValue; auto src = mlir::cast(op.getOperand()); auto dst_ty = mlir::cast(op.getType()); - if (dst_ty.getWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit truncf"); + if (dst_ty.getWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) truncf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -524,8 +532,8 @@ struct RewriteExtFPattern : public mlir::OpRewritePattern { using FloatValue = mlir::TypedValue; auto src = mlir::cast(op.getOperand()); auto dst_ty = mlir::cast(op.getType()); - if (src.getType().getWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit extf"); + if (src.getType().getWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) extf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -544,8 +552,8 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { auto lhs = mlir::cast(op.getLhs()); auto rhs = mlir::cast(op.getRhs()); - if (lhs.getType().getWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit cmpf"); + if (lhs.getType().getWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) cmpf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -553,16 +561,16 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { llvm::APFloat rhs_cst(rhs.getType().getFloatSemantics()); if (op.getPredicate() == ma::CmpFPredicate::UNE && mlir::matchPattern(rhs, mlir::m_ConstantFloat(&rhs_cst))) { - Val int_value{b.create(rewriter.getI8Type(), lhs), &b}; + mlir::Type int_ty = rewriter.getIntegerType(lhs.getType().getWidth()); + Val int_value{b.create(int_ty, lhs), &b}; int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue(); // If we're comparing to +-0, compare the absolute values. - if (rhs_cst.isZero() && - (lhs.getType().isFloat8E3M4() || lhs.getType().isFloat8E4M3() || - lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) { - int_value = int_value & 0x7f; - constant &= 0x7f; + if (rhs_cst.isZero() && !IsFNUZ(lhs.getType())) { + int64_t mask = (1 << (lhs.getType().getWidth() - 1)) - 1; + int_value = int_value & mask; + constant &= mask; } - auto cst = b.create(constant, rewriter.getI8Type()); + auto cst = b.create(constant, int_ty); rewriter.replaceOpWithNewOp(op, ma::CmpIPredicate::ne, int_value, cst); return mlir::success(); @@ -586,18 +594,15 @@ struct RewriteAbsFPattern : public mlir::OpRewritePattern { auto src = mlir::cast(op.getOperand()); // LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16. // Once that's removed, remove the code for BF16 here. - if (src.getType().getWidth() != 8 && !src.getType().isBF16()) { - return rewriter.notifyMatchFailure(op, "not an f8 or bf16 absf"); + if (src.getType().getWidth() > 8 && !src.getType().isBF16()) { + return rewriter.notifyMatchFailure(op, + "not an f8 (or less) or bf16 absf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth()); Val value{b.create(i_ty, src), &b}; - if (src.getType().getWidth() == 8) { - value = value & 0x7f; - } else { - CHECK(src.getType().isBF16()); - value = value & 0x7fff; - } + int64_t mask = (1ull << (src.getType().getWidth() - 1)) - 1; + value = value & mask; rewriter.replaceOpWithNewOp(op, src.getType(), value); return mlir::success(); } @@ -609,8 +614,8 @@ struct RewriteIToFpPattern : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter& rewriter) const override { - if (op.getType().getIntOrFloatBitWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an f8 itofp"); + if (op.getType().getIntOrFloatBitWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an f8 (or less) itofp"); } Value to_float = rewriter.create(op.getLoc(), rewriter.getF32Type(), op.getIn()); @@ -625,8 +630,8 @@ struct RewriteFpToIPattern : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter& rewriter) const override { - if (op.getIn().getType().getIntOrFloatBitWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an f8 fptoi"); + if (op.getIn().getType().getIntOrFloatBitWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an f8 (or less) fptoi"); } Value to_f32 = rewriter.create( op.getLoc(), rewriter.getF32Type(), op.getIn()); diff --git a/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir b/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir index ff97c1b46f708..2b2d8b9a368e3 100644 --- a/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir +++ b/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir @@ -115,3 +115,40 @@ module { // CHECK: %[[EXT:.*]] = arith.extf {{.*}} : bf16 to f32 // CHECK: arith.truncf %[[EXT]] : f32 to f16 // CHECK-NOT: arith.truncf + +// ----- + +module { + func.func @f4_to_f16(%arg0: f4E2M1FN) -> f16 { + %ret = arith.extf %arg0 : f4E2M1FN to f16 + return %ret : f16 + } +} + +// CHECK-LABEL: @f4_to_f16 +// CHECK-NOT: arith.extf + +// ----- + +module { + func.func @f16_to_f4(%arg0: f16) -> f4E2M1FN { + %ret = arith.truncf %arg0 : f16 to f4E2M1FN + return %ret : f4E2M1FN + } +} + +// CHECK-LABEL: @f16_to_f4 +// CHECK-NOT: arith.truncf + +// ----- + +module { + func.func @f4_abs(%arg0: f4E2M1FN) -> f4E2M1FN { + %ret = math.absf %arg0 : f4E2M1FN + return %ret : f4E2M1FN + } +} + +// CHECK-LABEL: @f4_abs +// CHECK-NOT: math.absf +// CHECK: arith.constant 7 : i4 diff --git a/xla/service/gpu/tests/float_conversions_test.cc b/xla/service/gpu/tests/float_conversions_test.cc index 16383324dfb01..62179eefcf3c5 100644 --- a/xla/service/gpu/tests/float_conversions_test.cc +++ b/xla/service/gpu/tests/float_conversions_test.cc @@ -29,8 +29,8 @@ class FloatConversionParamTest INSTANTIATE_TEST_SUITE_P(FloatConversionParamSuite, FloatConversionParamTest, ::testing::Values("f64", "f32", "f16", "bf16", - "f8e5m2", "f8e5m2fnuz", "f8e4m3", - "f8e4m3fn", "f8e4m3fnuz", + "f4e2m1fn", "f8e5m2", "f8e5m2fnuz", + "f8e4m3", "f8e4m3fn", "f8e4m3fnuz", "f8e4m3b11fnuz", "f8e3m4")); TEST_P(FloatConversionParamTest, FloatToF16) { diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index 4f06ea0cc290c..75d9653c3e700 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -53,10 +53,10 @@ class ConvertTestT : public ConvertTest { public: using ConvertTest::ConvertTest; }; -using FloatingPointTypeList = - ::testing::Types; +using FloatingPointTypeList = ::testing::Types< + tsl::float4_e2m1fn, tsl::float8_e3m4, tsl::float8_e4m3, tsl::float8_e4m3fn, + tsl::float8_e4m3fnuz, tsl::float8_e4m3b11fnuz, tsl::float8_e5m2, + tsl::float8_e5m2fnuz, Eigen::half, bfloat16, float, double>; TYPED_TEST_SUITE(ConvertTestT, FloatingPointTypeList); template @@ -741,7 +741,7 @@ XLA_TYPED_TEST(ConvertTestT, ConvertFPToPred) { XlaBuilder builder(this->TestName()); using FP = TypeParam; - auto a = ConstantR1(&builder, {FP{0.0}, FP{0.25}, FP{2.0}, FP{-0.0}}); + auto a = ConstantR1(&builder, {FP{0.0}, FP{0.5}, FP{2.0}, FP{-0.0}}); ConvertElementType(a, PRED); std::array expected = {false, true, true, false}; @@ -1925,5 +1925,165 @@ XLA_TYPED_TEST(ConvertTestF16, ConvertF8e3m4F16RoundtripExhaustive4) { this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F4E2M1FN + +XLA_TEST_F(ConvertTest, ConvertF16F4e2m1fnRoundtrip) { + // Convert from FP16 to FP4, then back to FP16. + XlaBuilder builder(TestName()); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, 0x1.8p2}, + // clang-format on + {0x1.4p0, 0x1p0}, // Round-to-even down + {0x1.Cp0, 0x1p1}, // Round-to-even up + {0x1.8p2, 0x1.8p2}, // Max value + {0x1.BFCp2, 0x1.8p2}, // Largest number that doesn't overflow + {0x1.Cp2, 0x1.8p2}, // Smallest number that overflows + {0x1p3, 0x1.8p2}, // Overflow + {0x1p0, 0x1p0}, // Smallest F8 normal + {0x1.8p-1, 0x1p0}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-1, 0x1.0p-1}, // Denormal without rounding + {0x1.4p-1, 0x1.0p-1}, // Round-to-even down + {0x1.Cp-1, 0x1.0p0}, // Round-to-even up + {0x1.6p-1, 0x1.0p-1}, // Round-to-nearest down + {0x1.Ep-1, 0x1.0p0}, // Round-to-nearest up + {0x1p-2, 0}, // Largest number that underflows + {0x1.004p-2, 0x1p-1}, // Smallest number that doesn't underflow + {0x1.7FCp-1, 0x1p-1}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); + } + + auto f4 = + ConvertElementType(ConstantR1(&builder, inputs), F4E2M1FN); + ConvertElementType(f4, F16); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F4e2m1fnRoundtrip)) { + // Convert from FP32 to FP4, then back to FP32. + XlaBuilder builder(TestName()); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, 0x1.8p2}, + // clang-format on + {0x1.4p0, 0x1p0}, // Round-to-even down + {0x1.Cp0, 0x1p1}, // Round-to-even up + {0x1.8p2, 0x1.8p2}, // Max value + {0x1.BFFFFEp2, 0x1.8p2}, // Largest number that doesn't overflow + {0x1.Cp2, 0x1.8p2}, // Smallest number that overflows + {0x1p3, 0x1.8p2}, // Overflow + {0x1p0, 0x1p0}, // Smallest F8 normal + {0x1.8p-1, 0x1p0}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-1, 0x1.0p-1}, // Denormal without rounding + {0x1.4p-1, 0x1.0p-1}, // Round-to-even down + {0x1.Cp-1, 0x1.0p0}, // Round-to-even up + {0x1.6p-1, 0x1.0p-1}, // Round-to-nearest down + {0x1.Ep-1, 0x1.0p0}, // Round-to-nearest up + {0x1p-2, 0}, // Largest number that underflows + {0x1.000002p-2, 0x1p-1}, // Smallest number that doesn't underflow + {0x1.7FFFFEp-1, 0x1p-1}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f4 = ConvertElementType(ConstantR1(&builder, inputs), F4E2M1FN); + ConvertElementType(f4, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF4e2m1fnRoundtripExhaustive) { + // Convert from FP4 to supported floating point type, then back to FP4. + XlaBuilder builder(this->TestName()); + + using From = tsl::float4_e2m1fn; + std::vector all_f4; + for (int i = 0; i < 16; i++) { + all_f4.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f4_as_fp = + ConvertElementType(ConstantR1(&builder, all_f4), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f4_as_fp, F4E2M1FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF4e2m1fnRoundtripExhaustive2) { + // Convert from supported floating point type to FP4. + XlaBuilder builder(this->TestName()); + + std::vector all_f4; + for (int i = 0; i < 16; i++) { + all_f4.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f4), F4E2M1FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF4e2m1fnRoundtripExhaustive3) { + // Convert from FP4 to supported floating point type. + XlaBuilder builder(this->TestName()); + + using From = tsl::float4_e2m1fn; + std::vector all_f4; + for (int i = 0; i < 16; i++) { + all_f4.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f4), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, ConvertF4e2m1fnF16RoundtripExhaustive4) { + // Convert from (B)F16 to FP4. + XlaBuilder builder(this->TestName()); + + std::vector all_f16; + for (int i = 0; i < 65536; i++) { + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f16), F4E2M1FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + } // namespace } // namespace xla