Skip to content

Commit

Permalink
Add F4E2M1FN type: conversion codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-kozub committed Dec 3, 2024
1 parent e919ed5 commit ca16839
Show file tree
Hide file tree
Showing 7 changed files with 383 additions and 66 deletions.
106 changes: 103 additions & 3 deletions xla/service/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,82 @@ llvm::Value* EmitF8e4m3b11fnuzToF16(llvm::Value* f8_value,
return f16_value;
}

// Emits IR that converts an F16 value into an F4E2M1FN bit pattern.
//
// The result is an i8 whose low three bits hold the magnitude code and whose
// bit 3 holds the sign. The conversion proceeds in three stages:
//   1. Round the input with EmitReducePrecisionIR to a format that has one
//      more exponent bit and one fewer significand bit than primitive_util
//      reports for F4E2M1FN, then re-bias the exponent/mantissa field.
//   2. Override the result for large inputs: magnitudes at or above 7.0
//      become 0x7, and NaN inputs become 0x8.
//   3. Override the result for small inputs, where the re-biased field is not
//      a valid encoding, by comparing the original magnitude against 0.75 and
//      0.25 and selecting the code 0x2, 0x1 or 0x0 directly.
//
// Returns an error only if EmitReducePrecisionIR fails.
absl::StatusOr<llvm::Value*> EmitF16ToF4e2m1fn(llvm::Value* f16_value,
                                               llvm::IRBuilder<>* b) {
  // Helpers for integer constants of the two widths used below.
  const auto const_i8 = [b](int v) {
    return llvm::ConstantInt::get(b->getInt8Ty(), v);
  };
  const auto const_i16 = [b](int v) {
    return llvm::ConstantInt::get(b->getInt16Ty(), v);
  };

  // Stage 1: round in the F16 domain, then keep only the top seven bits
  // (sign, five exponent bits, one mantissa bit) of the rounded value.
  TF_ASSIGN_OR_RETURN(
      llvm::Value * rounded,
      EmitReducePrecisionIR(
          /*src_ty=*/F16, f16_value,
          /*dest_exponent_bits=*/primitive_util::ExponentWidth(F4E2M1FN) + 1,
          /*dest_mantissa_bits=*/primitive_util::SignificandWidth(F4E2M1FN) - 1,
          /*quiet_nans=*/false, b));
  llvm::Value* rounded_bits = b->CreateBitCast(rounded, b->getInt16Ty());
  llvm::Value* rounded_top = b->CreateLShr(rounded_bits, 9);
  llvm::Value* packed = b->CreateTrunc(rounded_top, b->getInt8Ty());

  // Split the packed byte: bit 6 is the sign, bits 0..5 are exponent and
  // mantissa. Subtracting 28 re-biases the exponent; this looks like the
  // 14-step bias difference (cf. `14 << 10` in EmitF4e2m1fnToF16) shifted
  // left past the single mantissa bit.
  llvm::Value* sign_bit = b->CreateLShr(packed, 6);
  llvm::Value* exp_mant = b->CreateAnd(packed, const_i8(0x3F));
  llvm::Value* rebiased = b->CreateSub(exp_mant, const_i8(28));

  // Stage 2: large magnitudes, decided on the *unrounded* input magnitude.
  llvm::Value* input_bits = b->CreateBitCast(f16_value, b->getInt16Ty());
  llvm::Value* magnitude = b->CreateAnd(input_bits, const_i16(0x7FFF));
  llvm::Value* too_large =
      b->CreateICmpUGE(magnitude, const_i16(0x4700));  // 7.0
  llvm::Value* input_is_nan =
      b->CreateICmpUGT(magnitude, const_i16(0x7C00));  // inf
  llvm::Value* saturated =
      b->CreateSelect(input_is_nan, const_i8(0x8), const_i8(0x7));
  llvm::Value* normal_or_saturated =
      b->CreateSelect(too_large, saturated, rebiased);

  // Stage 3: small magnitudes. The signed compare also catches negative
  // re-biased values produced by inputs below the target exponent range.
  llvm::Value* too_small = b->CreateICmpSLE(rebiased, const_i8(1));
  llvm::Value* rounds_to_one =
      b->CreateICmpUGE(magnitude, const_i16(0x3A00));  // 0.75
  llvm::Value* rounds_to_zero =
      b->CreateICmpULE(magnitude, const_i16(0x3400));  // 0.25
  llvm::Value* half_or_zero =
      b->CreateSelect(rounds_to_zero, const_i8(0x0), const_i8(0x1));
  llvm::Value* small_code =
      b->CreateSelect(rounds_to_one, const_i8(0x2), half_or_zero);
  llvm::Value* unsigned_code =
      b->CreateSelect(too_small, small_code, normal_or_saturated);

  // Place the sign in bit 3, above the three magnitude bits.
  llvm::Value* sign_in_place = b->CreateShl(sign_bit, 3);
  return b->CreateOr(unsigned_code, sign_in_place);
}

// Emits IR that expands an F4E2M1FN bit pattern into an F16 (half) value.
//
// `f8_value` (despite its name) carries the 4-bit encoding in its low bits:
// bits 0..2 are the magnitude code and bit 3 is the sign. The magnitude code
// is mapped as follows:
//   0      -> 0x0000 (zero)
//   1      -> 0x3800 (0.5 in F16; the only subnormal code)
//   2..7   -> (code << 9) + (14 << 10), i.e. the exponent/mantissa field is
//             moved into place and the exponent is re-biased by 14.
// The sign is then moved to bit 15 and the i16 is bitcast to half.
llvm::Value* EmitF4e2m1fnToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) {
  const auto const_i16 = [b](int v) {
    return llvm::ConstantInt::get(b->getInt16Ty(), v);
  };

  llvm::Value* widened = b->CreateZExt(f8_value, b->getInt16Ty());

  // Move the sign (bit 3) up to the F16 sign position (bit 15). The shift
  // left by 15 in a 16-bit value discards anything above the sign bit.
  llvm::Value* sign_bit = b->CreateLShr(widened, 3);
  llvm::Value* f16_sign = b->CreateShl(sign_bit, 15);

  // Isolate the three magnitude bits and align the single mantissa bit with
  // the top bit of the F16 mantissa.
  llvm::Value* magnitude = b->CreateAnd(widened, const_i16(0x7));
  llvm::Value* aligned = b->CreateShl(magnitude, 9);

  // Normal values only need their exponent re-biased; codes 0 and 1 are
  // replaced by fixed F16 patterns instead.
  llvm::Value* normal = b->CreateAdd(aligned, const_i16(14 << 10));
  llvm::Value* is_subnormal_or_zero =
      b->CreateICmpULE(magnitude, const_i16(1));
  llvm::Value* is_exact_zero = b->CreateICmpEQ(magnitude, const_i16(0));
  llvm::Value* zero_or_half =
      b->CreateSelect(is_exact_zero, const_i16(0x0000), const_i16(0x3800));
  llvm::Value* unsigned_f16 =
      b->CreateSelect(is_subnormal_or_zero, zero_or_half, normal);

  // Combine magnitude and sign, then reinterpret the bits as half.
  llvm::Value* signed_f16 = b->CreateOr(unsigned_f16, f16_sign);
  return b->CreateBitCast(signed_f16, b->getHalfTy());
}

llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
PrimitiveType from_type,
PrimitiveType to_type, llvm::Module* module,
Expand Down Expand Up @@ -902,6 +978,12 @@ absl::StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
b_),
b_);
}
if (to_type == F4E2M1FN) {
return EmitF16ToF4e2m1fn(
EmitIntegralToFloating(operand_value, from_type, F16, module_,
b_),
b_);
}
if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) {
return EmitFloatingToF8fnuz(
F16,
Expand Down Expand Up @@ -1105,6 +1187,14 @@ absl::StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
return operand_value;
}
}
if (from_type == F4E2M1FN) {
TF_RET_CHECK(to_type != F4E2M1FN);
operand_value = EmitF4e2m1fnToF16(operand_value, b_);
from_type = F16;
if (from_type == to_type) {
return operand_value;
}
}
if (from_type == F8E5M2FNUZ || from_type == F8E4M3FNUZ) {
TF_RET_CHECK(to_type != from_type);
PrimitiveType cast_type =
Expand Down Expand Up @@ -1176,6 +1266,14 @@ absl::StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
return EmitF16ToF8e4m3b11fnuz(operand_value, b_);
}
if (to_type == F4E2M1FN) {
// Cast to F16 first. Casts to F4E2M1FN must be from F16.
if (from_type != F16) {
operand_value = b_->CreateFPCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_));
}
return EmitF16ToF4e2m1fn(operand_value, b_);
}
if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) {
return EmitFloatingToF8fnuz(from_type, operand_value, to_type, b_);
}
Expand Down Expand Up @@ -1721,6 +1819,9 @@ absl::StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
} else if (operand_type == F8E4M3FN) {
lhs_value = EmitF8e4m3fnToF16(lhs_value, b_);
rhs_value = EmitF8e4m3fnToF16(rhs_value, b_);
} else if (operand_type == F4E2M1FN) {
lhs_value = EmitF4e2m1fnToF16(lhs_value, b_);
rhs_value = EmitF4e2m1fnToF16(rhs_value, b_);
} else if (operand_type == F8E5M2FNUZ || operand_type == F8E4M3FNUZ) {
TF_ASSIGN_OR_RETURN(
lhs_value,
Expand Down Expand Up @@ -3569,9 +3670,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
primitive_util::IsFloatingPointType(component_element_type))
<< component_element_type;
llvm::Type* float_ir_type;
if (component_element_type == F8E4M3FNUZ) {
float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_);
} else if (component_element_type == F8E5M2FNUZ) {
if (component_element_type == F8E4M3FNUZ ||
component_element_type == F8E5M2FNUZ) {
float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_);
} else {
float_ir_type =
Expand Down
13 changes: 9 additions & 4 deletions xla/service/elemental_ir_emitter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ class ElementalIrEmitterExecutionTypedTest
};

using FloatTypes =
::testing::Types<bfloat16, tsl::float8_e5m2, tsl::float8_e5m2fnuz,
tsl::float8_e4m3, tsl::float8_e4m3fn, tsl::float8_e4m3fnuz,
tsl::float8_e4m3b11fnuz, tsl::float8_e3m4>;
::testing::Types<bfloat16, tsl::float4_e2m1fn, tsl::float8_e3m4,
tsl::float8_e4m3, tsl::float8_e4m3b11fnuz,
tsl::float8_e4m3fn, tsl::float8_e4m3fnuz, tsl::float8_e5m2,
tsl::float8_e5m2fnuz>;

TYPED_TEST_SUITE(ElementalIrEmitterExecutionTypedTest, FloatTypes);

Expand Down Expand Up @@ -614,7 +615,8 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) {
std::is_same<TypeParam, tsl::float8_e4m3>() ||
std::is_same<TypeParam, tsl::float8_e4m3fn>() ||
std::is_same<TypeParam, tsl::float8_e4m3b11fnuz>() ||
std::is_same<TypeParam, tsl::float8_e3m4>()) {
std::is_same<TypeParam, tsl::float8_e3m4>() ||
std::is_same<TypeParam, tsl::float4_e2m1fn>()) {
GTEST_SKIP() << "Skipping test for type " << tname;
}
const auto hlo_text = absl::StrReplaceAll(R"(
Expand All @@ -629,6 +631,9 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) {

TYPED_TEST(ElementalIrEmitterExecutionTypedTest, BatchDotFloat) {
auto tname = this->TypeName();
if (std::is_same<TypeParam, tsl::float4_e2m1fn>()) {
GTEST_SKIP() << "Dot operation on E2M1 is not supported";
}
const auto hlo_text = absl::StrReplaceAll(R"(
HloModule matmul
Expand Down
14 changes: 12 additions & 2 deletions xla/service/float8_fnuz_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ namespace {
absl::StatusOr<const llvm::fltSemantics*> PrimitiveTypeToAPFloatSemantics(
PrimitiveType type) {
switch (type) {
case F4E2M1FN:
return &llvm::APFloat::Float4E2M1FN();
case F8E3M4:
return &llvm::APFloat::Float8E3M4();
case F8E4M3:
Expand Down Expand Up @@ -72,6 +74,8 @@ absl::StatusOr<const llvm::fltSemantics*> PrimitiveTypeToAPFloatSemantics(
absl::StatusOr<llvm::Type*> PrimitiveTypeToLLVMType(llvm::IRBuilderBase* b,
PrimitiveType type) {
switch (type) {
case F4E2M1FN:
return b->getIntNTy(4);
case F8E3M4:
case F8E4M3:
case F8E4M3B11FNUZ:
Expand Down Expand Up @@ -649,8 +653,14 @@ absl::StatusOr<llvm::Value*> EmitF8fnuzToFloating(PrimitiveType input_type,
llvm::ConstantInt::get(b->getInt8Ty(), 0x0u), sign);

// Bitwise or the sign bit back in.
sign = b->CreateZExt(sign, output_int_type);
sign = b->CreateShl(sign, output_type_bit_width - BitWidth(input_type));
int shift = output_type_bit_width - BitWidth(input_type);
if (shift >= 0) {
sign = b->CreateZExt(sign, output_int_type);
sign = b->CreateShl(sign, shift);
} else {
sign = b->CreateLShr(sign, -shift);
sign = b->CreateTrunc(sign, output_int_type);
}
llvm::Value* result = b->CreateOr(sign, result_abs);

// Bitcast to the output type.
Expand Down
Loading

0 comments on commit ca16839

Please sign in to comment.