Add F4E2M1FN type: add tests
sergey-kozub committed Dec 18, 2024
1 parent 999bf96 commit d7d5af7
Showing 25 changed files with 186 additions and 70 deletions.
14 changes: 14 additions & 0 deletions xla/array2d_test.cc
@@ -219,6 +219,20 @@ TEST(Array2dTest, LinspaceF8E3M4) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, LinspaceF4E2M1FN) {
auto arr = MakeLinspaceArray2D<tsl::float4_e2m1fn>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}

TEST(Array2dTest, Stringification) {
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
const std::string expected = R"([[1, 1.5],
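The two inline comments in the new test follow from F4E2M1FN having only one mantissa bit: its non-negative values are 0, 0.5, 1, 1.5, 2, 3, 4 and 6, and conversion rounds to nearest with ties going to the even mantissa, so 2.5 lands on 2.0 and 3.5 on 4.0. A minimal standalone sketch of that rounding (illustrative only, not the tsl::float4_e2m1fn implementation):

```cpp
#include <algorithm>
#include <array>
#include <cmath>

// Round a float to the nearest F4E2M1FN value, ties to even mantissa.
float RoundToE2M1(float x) {
  // All non-negative E2M1 values in order; even indices have mantissa bit 0.
  static constexpr std::array<float, 8> kValues = {0.0f, 0.5f, 1.0f, 1.5f,
                                                   2.0f, 3.0f, 4.0f, 6.0f};
  float mag = std::min(std::fabs(x), 6.0f);  // finite-only type: saturate
  float best = kValues[0];
  for (int i = 0; i < static_cast<int>(kValues.size()); ++i) {
    float d = std::fabs(mag - kValues[i]);
    float best_d = std::fabs(mag - best);
    if (d < best_d || (d == best_d && i % 2 == 0)) best = kValues[i];
  }
  return std::copysign(best, x);
}

// RoundToE2M1(2.5f) == 2.0f and RoundToE2M1(3.5f) == 4.0f, matching the test.
```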
26 changes: 17 additions & 9 deletions xla/backends/gpu/codegen/transforms/lower_tensors.cc
@@ -297,7 +297,7 @@ std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
Value linear_index, mlir::ImplicitLocOpBuilder& b) {
Type element_type = tensor.getType().getElementType();
if (element_type == b.getI4Type()) {
if (element_type.getIntOrFloatBitWidth() == 4) {
element_type = b.getI8Type();
}
auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext());
@@ -325,8 +325,9 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto linear_index = GetLinearIndex(op.getIndices(), b);
Type element_type = op.getTensor().getType().getElementType();

Value is_low_nibble = nullptr;
if (element_type == rewriter.getI4Type()) {
if (element_type.getIntOrFloatBitWidth() == 4) {
std::tie(linear_index, is_low_nibble) =
GetI4IndexAndNibble(linear_index, b);
}
@@ -341,7 +342,7 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
auto high_value = b.create<mlir::arith::ShRUIOp>(
load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
load = b.create<mlir::arith::TruncIOp>(
op.getType(),
rewriter.getI4Type(),
b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
}

@@ -377,6 +378,7 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {

auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
op.getSource());
mlir::Type source_element_type = source.getType().getElementType();

mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto linear_index = GetLinearIndex(op.getIndices(), b);
@@ -385,7 +387,8 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
if (vector_type.getElementType().isInteger(1)) {
vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type());
}
if (op.getVectorType().getElementType().isInteger(4)) {
mlir::Type gep_element_type = vector_type.getElementType();
if (gep_element_type.getIntOrFloatBitWidth() == 4) {
linear_index = b.create<arith::ShRUIOp>(
linear_index,
b.create<arith::ConstantIntOp>(1, linear_index.getType()));
@@ -397,11 +400,11 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
auto loaded =
b.create<mlir::LLVM::LoadOp>(llvm_vector_type, gep).getResult();

if (source.getType().getElementType().isInteger(1)) {
if (source_element_type.isInteger(1)) {
Value zero = b.create<mlir::arith::ConstantOp>(
mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0)));
loaded = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, loaded, zero);
} else if (source.getType().getElementType().isInteger(4)) {
} else if (source_element_type.getIntOrFloatBitWidth() == 4) {
// LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
// elements.
loaded = PermutePairsInVector(loaded, b);
@@ -430,7 +433,7 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
auto scalar_value = op.getScalar();

// For i4 we store 2 values into one byte. This needs special handling here.
if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) {
if (tensor_dest.getType().getElementType().getIntOrFloatBitWidth() == 4) {
// We need to use directly op.getDest() as input, otherwise the following
// rewrite might remove the only user of it.
tensor_dest = op.getDest();
@@ -448,6 +451,10 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
auto tensor_dest_i8 =
b.create<UnrealizedConversionCastOp>(tensor_ty, tensor_dest)
.getResult(0);
if (scalar_value.getType() != rewriter.getI4Type()) {
scalar_value =
b.create<arith::BitcastOp>(rewriter.getI4Type(), scalar_value);
}
scalar_value = b.create<mlir::arith::ExtUIOp>(ty, scalar_value);

// We need AtomicRMWOp because it can happen that different threads try to
@@ -507,12 +514,13 @@ struct RewriteTransferWrite : OpRewritePattern<mlir::vector::TransferWriteOp> {
auto linear_index = GetLinearIndex(op.getIndices(), b);

mlir::Value vector_value = op.getVector();
if (op.getVectorType().getElementType().isInteger(1)) {
mlir::Type vector_element_type = op.getVectorType().getElementType();
if (vector_element_type.isInteger(1)) {
vector_value = b.create<arith::ExtUIOp>(
op.getVectorType().cloneWith(std::nullopt, b.getI8Type()),
vector_value);
}
if (op.getVectorType().getElementType().isInteger(4)) {
if (vector_element_type.getIntOrFloatBitWidth() == 4) {
linear_index = b.create<arith::ShRUIOp>(
linear_index,
b.create<arith::ConstantIntOp>(1, linear_index.getType()));
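The checks rewritten above (getIntOrFloatBitWidth() == 4 instead of comparing against the i4 type) let the existing sub-byte path cover f4E2M1FN as well: two 4-bit elements share one byte, the byte index is linear_index >> 1, the low bit of the index selects the nibble, and a 4-bit float scalar is first bitcast to i4 before being widened and merged. A standalone sketch of that packing scheme (the nibble order shown is an assumption for illustration; the pass only needs loads and stores to agree, and the real store path uses an atomic read-modify-write because two threads may touch the same byte):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Read the 4-bit element at linear_index from a byte-packed buffer.
uint8_t LoadNibble(const std::vector<uint8_t>& buf, size_t linear_index) {
  uint8_t byte = buf[linear_index >> 1];         // GEP on an i8 buffer
  bool is_low_nibble = (linear_index & 1) == 0;  // even index -> low nibble
  return is_low_nibble ? (byte & 0x0F) : (byte >> 4);
}

// Write the 4-bit element at linear_index, preserving its neighbor.
void StoreNibble(std::vector<uint8_t>& buf, size_t linear_index,
                 uint8_t value4) {
  uint8_t& byte = buf[linear_index >> 1];
  if ((linear_index & 1) == 0) {
    byte = static_cast<uint8_t>((byte & 0xF0) | (value4 & 0x0F));
  } else {
    byte = static_cast<uint8_t>((byte & 0x0F) | ((value4 & 0x0F) << 4));
  }
}
```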
53 changes: 53 additions & 0 deletions xla/fp_util_test.cc
@@ -119,6 +119,59 @@ class FP8E4M3DistanceTest : public ::testing::Test {};
using F8E4M3Types = ::testing::Types<tsl::float8_e4m3, tsl::float8_e4m3fn>;
TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types);

TEST(FPDistanceTest, F4E2M1FNDistance) {
// a & b are equal
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(4.0)),
0);

// a & b have the same exponents
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(6.0)),
1);

// a & b have different exponents
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
tsl::float4_e2m1fn(2.0), tsl::float4_e2m1fn(4.0)),
2);

// 1 from 0 in the positive direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
tsl::float4_e2m1fn(0)),
1);

// 1 from 0 in the negative direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
-std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
tsl::float4_e2m1fn(0)),
1);

// a & b have different signs
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
-std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
std::numeric_limits<tsl::float4_e2m1fn>::denorm_min()),
2);

// 1 non denorm from 0 in the positive direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
std::numeric_limits<tsl::float4_e2m1fn>::min(),
tsl::float4_e2m1fn(0)),
2);

// 1 non denorm from 0 in the negative direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
-std::numeric_limits<tsl::float4_e2m1fn>::min(),
tsl::float4_e2m1fn(0)),
2);

// a & b have different signs
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
-std::numeric_limits<tsl::float4_e2m1fn>::min(),
std::numeric_limits<tsl::float4_e2m1fn>::min()),
4);
}

TEST(FPDistanceTest, F8E3M4Distance) {
// a & b are equal
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
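With only eight magnitudes, the distances asserted in F4E2M1FNDistance are easiest to read off the ordered value list: 0.5 is the subnormal denorm_min, 1.0 is the smallest normal min, so -min to +min is four steps. A table-based sketch of that counting (illustrative only; the real CalculateDistanceInFloats works on bit patterns for every float type):

```cpp
#include <array>
#include <cstdlib>

// Number of steps between two representable F4E2M1FN values.
int E2M1DistanceInFloats(float a, float b) {
  // Every representable F4E2M1FN value, ascending (+0 and -0 collapse to 0).
  static constexpr std::array<float, 15> kAll = {
      -6.0f, -4.0f, -3.0f, -2.0f, -1.5f, -1.0f, -0.5f, 0.0f,
      0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f};
  int ia = 0, ib = 0;
  for (int i = 0; i < static_cast<int>(kAll.size()); ++i) {
    if (kAll[i] == a) ia = i;
    if (kAll[i] == b) ib = i;
  }
  return std::abs(ia - ib);
}

// E2M1DistanceInFloats(-1.0f, 1.0f) == 4, matching the last expectation above:
// -min -> -denorm_min -> 0 -> denorm_min -> min.
```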
11 changes: 7 additions & 4 deletions xla/hlo/builder/lib/math.cc
@@ -184,6 +184,7 @@ XlaOp IsNegZero(XlaOp operand) {
case F32:
return Eq(BitcastConvertType(operand, U32),
ConstantR0WithType(&b, U32, uint32_t{1} << 31));
case F4E2M1FN:
case F8E3M4:
case F8E4M3:
case F8E5M2:
@@ -971,8 +972,9 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
PrimitiveType a_x_type = a_shape.element_type();
bool needs_upcast = false;
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type :
{BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN,
F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
@@ -1024,8 +1026,9 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
bool needs_upcast = false;
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type :
{BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN,
F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
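Both loops above perform the same membership test, now with F4E2M1FN in the list: element types too narrow for the Igamma series evaluation are upcast to F32 and the result is converted back afterwards. A sketch of that check as a helper (NeedsF32UpcastForIgamma is a hypothetical name; the real code keeps the list inline):

```cpp
#include "xla/xla_data.pb.h"  // generated from xla_data.proto; defines xla::PrimitiveType

bool NeedsF32UpcastForIgamma(xla::PrimitiveType type) {
  switch (type) {
    case xla::BF16:
    case xla::F16:
    case xla::F4E2M1FN:
    case xla::F8E3M4:
    case xla::F8E4M3:
    case xla::F8E4M3B11FNUZ:
    case xla::F8E4M3FN:
    case xla::F8E4M3FNUZ:
    case xla::F8E5M2:
    case xla::F8E5M2FNUZ:
      return true;
    default:
      return false;
  }
}
```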
34 changes: 24 additions & 10 deletions xla/hlo/builder/lib/math_test.cc
@@ -95,17 +95,22 @@ class MathTypedTest : public MathTest {
Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)});

bool has_inf = std::numeric_limits<T>::has_infinity;
bool has_nan = std::numeric_limits<T>::has_quiet_NaN;
bool is_finite = !has_inf && !has_nan;
bool is_nan_only = !has_inf && has_nan;

auto expected = LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR1<bool>(
{true, true, true, true, true, false, false, false, false}),
LiteralUtil::CreateR1<bool>({true, true, true, true, true, is_finite,
is_finite, is_finite, is_finite}),
LiteralUtil::CreateR1<bool>({false, false, false, false, false, has_inf,
has_inf, false, false}),
LiteralUtil::CreateR1<bool>(
{false, false, false, false, false, has_inf, false, false, false}),
LiteralUtil::CreateR1<bool>(
{false, false, false, false, false, false, has_inf, false, false}),
LiteralUtil::CreateR1<bool>({false, false, false, false, false,
!has_inf, !has_inf, true, true}));
is_nan_only, is_nan_only, has_nan,
has_nan}));
ComputeAndCompareLiteral(&b, expected, {});
}

@@ -118,10 +123,11 @@ class MathTypedTest : public MathTest {
LiteralUtil::CreateR1<T>({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}),
&b));

bool is_mx = std::is_same_v<T, tsl::float4_e2m1fn>;
ComputeAndCompareLiteral(
&b,
LiteralUtil::CreateR1<bool>(
{has_negative_zero_v<T>, false, false, false, false, false, false}),
{has_negative_zero_v<T>, false, false, false, false, false, is_mx}),
{}, error_spec_);
}

@@ -136,6 +142,9 @@ class MathTypedTest : public MathTest {
// For good measure, we also check pow with an exponent other than 0.5.
void TestSqrtPowInequivalence() {
SetFastMathDisabled(true);
if (std::is_same_v<T, tsl::float4_e2m1fn>) {
GTEST_SKIP() << "Skipping due to low precision";
}

// Tests disable constant folding by default, but this test needs it
// enabled, otherwise we don't tickle the bug we're trying to catch.
@@ -181,18 +190,23 @@ class MathTypedTest : public MathTest {
&b);
Erf(x);

bool has_inf = std::numeric_limits<T>::has_infinity;
std::vector<T> expected = {
has_inf ? T(-1) : nan, has_inf ? T(1) : nan, T(-0), T(0), T(-1), T(1)};
bool inf_as_nan = !std::numeric_limits<T>::has_infinity &&
std::numeric_limits<T>::has_quiet_NaN;
std::vector<T> expected = {inf_as_nan ? nan : T(-1),
inf_as_nan ? nan : T(1),
T(-0),
T(0),
T(-1),
T(1)};

ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
}
};

using TestTypes =
::testing::Types<tsl::float8_e3m4, tsl::float8_e4m3, tsl::float8_e4m3fnuz,
tsl::float8_e4m3b11fnuz, tsl::float8_e5m2,
tsl::float8_e5m2fnuz,
::testing::Types<tsl::float4_e2m1fn, tsl::float8_e3m4, tsl::float8_e4m3,
tsl::float8_e4m3fnuz, tsl::float8_e4m3b11fnuz,
tsl::float8_e5m2, tsl::float8_e5m2fnuz,
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
Eigen::half,
#endif
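The reworked expectations in math_test.cc split the test types into three groups purely from std::numeric_limits: types with an infinity (and NaN), types with only a NaN, and F4E2M1FN, which has neither, so IsInf/IsNan are always false and IsFinite is always true for it. A small sketch of that classification, using the same queries as the test:

```cpp
#include <limits>

template <typename T>
const char* ClassifyFloatType() {
  constexpr bool has_inf = std::numeric_limits<T>::has_infinity;
  constexpr bool has_nan = std::numeric_limits<T>::has_quiet_NaN;
  if (has_inf) return "IEEE-like: inf and NaN (half, float8_e5m2, float8_e4m3, float8_e3m4)";
  if (has_nan) return "NaN-only: no inf (float8_e4m3fn and the *fnuz types)";
  return "finite-only: no inf, no NaN (float4_e2m1fn)";
}
```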
3 changes: 3 additions & 0 deletions xla/hlo/transforms/simplifiers/float_normalization.cc
@@ -217,6 +217,9 @@ absl::Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
hlo->mutable_shape(), [&](Shape* subshape, const xla::ShapeIndex& index) {
if (subshape->element_type() == from) {
subshape->set_element_type(to);
if (subshape->has_layout() && from == F4E2M1FN) {
subshape->mutable_layout()->set_element_size_in_bits(0);
}
}
});
float_normalization_->UpdateLayout(hlo->mutable_shape());
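The extra layout fix-up is needed because F4E2M1FN is a packed sub-byte type: its layouts carry element_size_in_bits = 4, and once normalization widens the buffer to a byte-or-larger type that field must revert to the default 0, or downstream size math would still assume two elements per byte. A back-of-the-envelope sketch of what the field changes (illustrative only):

```cpp
#include <cstdint>

// Bytes needed for num_elements packed elements of the given bit width.
int64_t PackedByteSize(int64_t num_elements, int64_t bits_per_element) {
  return (num_elements * bits_per_element + 7) / 8;
}

// A 2x3 F4E2M1FN array packs into PackedByteSize(6, 4) == 3 bytes; after
// normalization to F16 it needs PackedByteSize(6, 16) == 12 bytes, so a stale
// element_size_in_bits = 4 must not survive the element type change.
```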
4 changes: 3 additions & 1 deletion xla/hlo/transforms/simplifiers/float_normalization_test.cc
@@ -150,7 +150,9 @@ class FloatNormalizationF8Test
public ::testing::WithParamInterface<PrimitiveType> {};

INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test,
::testing::Values(F8E3M4, F8E4M3, F8E5M2));
::testing::Values(F4E2M1FN, F8E3M4, F8E4M3,
F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ,
F8E5M2, F8E5M2FNUZ));

TEST_F(FloatNormalizationTest, NoopIfSupported) {
auto builder = HloComputation::Builder(TestName());
6 changes: 5 additions & 1 deletion xla/mlir/utils/type_util.cc
@@ -32,6 +32,8 @@ absl::StatusOr<mlir::Type> ConvertPrimitiveTypeToMlirType(
switch (type) {
case xla::PrimitiveType::PRED:
return b.getI1Type();
case xla::PrimitiveType::F4E2M1FN:
return b.getFloat4E2M1FNType();
case xla::PrimitiveType::F8E5M2:
return b.getFloat8E5M2Type();
case xla::PrimitiveType::F8E4M3:
@@ -78,7 +80,9 @@ }
}

xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) {
if (type.isFloat8E5M2()) {
if (type.isFloat4E2M1FN()) {
return xla::PrimitiveType::F4E2M1FN;
} else if (type.isFloat8E5M2()) {
return xla::PrimitiveType::F8E5M2;
} else if (type.isFloat8E4M3()) {
return xla::PrimitiveType::F8E4M3;
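With both directions handled, the new type should round-trip between xla::PrimitiveType and the MLIR builtin float type. A minimal sketch of that round trip (the header path and argument order are inferred from this file and are assumptions here):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "xla/mlir/utils/type_util.h"  // assumed header for the two helpers above

bool RoundTripsF4E2M1FN() {
  mlir::MLIRContext context;
  mlir::Builder b(&context);
  // Assumed signature: ConvertPrimitiveTypeToMlirType(PrimitiveType, mlir::Builder).
  auto mlir_type =
      xla::ConvertPrimitiveTypeToMlirType(xla::PrimitiveType::F4E2M1FN, b);
  return mlir_type.ok() &&
         xla::ConvertMlirTypeToPrimitiveType(*mlir_type) ==
             xla::PrimitiveType::F4E2M1FN;
}
```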
1 change: 1 addition & 0 deletions xla/mlir/utils/type_util_test.cc
@@ -101,6 +101,7 @@ INSTANTIATE_TEST_SUITE_P(
Execute, TypeUtilTest,
::testing::ValuesIn(std::vector<TypeUtilTestParam>(
{{PRED, [](mlir::Builder b) { return b.getI1Type(); }},
{F4E2M1FN, [](mlir::Builder b) { return b.getFloat4E2M1FNType(); }},
{F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }},
{F8E4M3, [](mlir::Builder b) { return b.getFloat8E4M3Type(); }},
{F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }},
7 changes: 7 additions & 0 deletions xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
@@ -6844,6 +6844,13 @@ func.func @invalid_dimension_attr(%arg0: tensor<?x?xf32, #mhlo.type_extensions<b

// -----

func.func @f4e2m1fn(%arg0: tensor<f16>) -> tensor<f4E2M1FN> {
%0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f4E2M1FN>
func.return %0 : tensor<f4E2M1FN>
}

// -----

func.func @f8e3m4(%arg0: tensor<f16>) -> tensor<f8E3M4> {
%0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E3M4>
func.return %0 : tensor<f8E3M4>
Expand Down