Skip to content

Commit

Permalink
Addressing PR#19096 review comments (round 2)
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-kozub committed Dec 18, 2024
1 parent 1e34417 commit d4de0a3
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 129 deletions.
99 changes: 0 additions & 99 deletions third_party/tsl/third_party/py/ml_dtypes/e8m0.patch

This file was deleted.

1 change: 0 additions & 1 deletion third_party/tsl/third_party/py/ml_dtypes/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def repo():
tf_http_archive(
name = "ml_dtypes",
build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD",
patch_file = ["//third_party/py/ml_dtypes:e8m0.patch"],
link_files = {
"//third_party/py/ml_dtypes:ml_dtypes.tests.BUILD": "tests/BUILD.bazel",
"//third_party/py/ml_dtypes:LICENSE": "LICENSE",
Expand Down
3 changes: 3 additions & 0 deletions xla/pjrt/c/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# PJRT C API changelog

## 0.60
* Added types F4E2M1FN and F8E8M0FNU.

## 0.59
* Added ``PJRT_MemoryDescriptions_Extension``.

Expand Down
2 changes: 1 addition & 1 deletion xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 59
#define PJRT_API_MINOR 60

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down
2 changes: 1 addition & 1 deletion xla/primitive_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
<< primitive_util::LowercasePrimitiveTypeName(to_type);
}
}
}
} // NOLINT(readability/fn_size)

} // namespace
} // namespace xla
6 changes: 4 additions & 2 deletions xla/service/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1436,15 +1436,17 @@ absl::StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
// Cast to F16 first. Casts to F4E2M1FN must be from F16.
if (from_type != F16) {
operand_value = b_->CreateFPCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_));
operand_value,
llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext()));
}
return EmitF16ToF4e2m1fn(operand_value, b_);
}
if (to_type == F8E8M0FNU) {
// Cast to F32 first. Casts to F8E8M0FNU must be from F32.
if (from_type != F32) {
operand_value = b_->CreateFPCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(F32, module_));
operand_value,
llvm_ir::PrimitiveTypeToIrType(F32, module_->getContext()));
}
return EmitF32ToF8e8m0fnu(operand_value, b_);
}
Expand Down
34 changes: 18 additions & 16 deletions xla/service/gpu/fusions/triton/triton_support_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,18 @@ INSTANTIATE_TEST_SUITE_P(

using ReduceTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam;

static std::string_view init_value(PrimitiveType dtype) {
if (dtype == C64 || dtype == C128) {
return "(0, 0)";
} else if (dtype == F8E8M0FNU) {
return "1e-40";
} else {
return "0";
}
}

TEST_P(ReduceTest, IsTritonSupportedReduction) {
auto [data_type, opcode, cc] = GetParam();
bool dtype_is_complex = data_type == C64 || data_type == C128;
const std::string kHloTestTemplate =
absl::Substitute(R"(
add {
Expand All @@ -567,7 +576,7 @@ ENTRY triton_computation {
ROOT reduce = $0[125] reduce(parameter_0, constant_0),
dimensions={1}, to_apply=add
})",
"$0", dtype_is_complex ? "(0, 0)" : "0");
"$0", init_value(data_type));
TF_ASSERT_OK_AND_ASSIGN(
TestedInstruction ti,
ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
Expand Down Expand Up @@ -599,7 +608,6 @@ TEST_P(
ReduceTest,
UnsupportedReduceWithMoreThanOneReduceDimensionsFailsGracefullyWithTriton) {
auto [data_type, opcode, cc] = GetParam();
bool dtype_is_complex = data_type == C64 || data_type == C128;
const std::string kHloTestTemplate =
absl::Substitute(R"(
add {
Expand All @@ -614,7 +622,7 @@ ENTRY triton_computation {
ROOT reduce = $0[2] reduce(parameter_0, constant_0),
dimensions={1,2}, to_apply=add
})",
"$0", dtype_is_complex ? "(0, 0)" : "0");
"$0", init_value(data_type));
TF_ASSERT_OK_AND_ASSIGN(
TestedInstruction ti,
ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
Expand All @@ -624,7 +632,6 @@ ENTRY triton_computation {

TEST_P(ReduceTest, IsTritonSupportedReduceWithNonLastReduceDimension) {
auto [data_type, opcode, cc] = GetParam();
bool dtype_is_complex = data_type == C64 || data_type == C128;
const std::string kHloTestTemplate =
absl::Substitute(R"(
add {
Expand All @@ -638,7 +645,7 @@ ENTRY triton_computation {
constant_0 = $0[] constant($1)
ROOT reduce = $0[127] reduce(parameter_0, constant_0), dimensions={0}, to_apply=add
})",
"$0", dtype_is_complex ? "(0, 0)" : "0");
"$0", init_value(data_type));
TF_ASSERT_OK_AND_ASSIGN(
TestedInstruction ti,
ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
Expand All @@ -649,7 +656,6 @@ ENTRY triton_computation {
TEST_P(ReduceTest,
UnsupportedReduceWithMoreThanOneOperandsFailsGracefullyWithTriton) {
auto [data_type, opcode, cc] = GetParam();
bool dtype_is_complex = data_type == C64 || data_type == C128;
const std::string kHloTestTemplate =
absl::Substitute(R"(
add {
Expand All @@ -670,7 +676,7 @@ ENTRY triton_computation {
dimensions={1}, to_apply=add
ROOT reduce = $0[125] get-tuple-element(tuple), index=0
})",
"$0", dtype_is_complex ? "(0, 0)" : "0");
"$0", init_value(data_type));
TF_ASSERT_OK_AND_ASSIGN(
TestedInstruction ti,
ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
Expand Down Expand Up @@ -701,7 +707,6 @@ ENTRY triton_computation {

TEST_P(ReduceTest, UnsupportedReductionComputationFailsGracefullyWithTriton) {
auto [data_type, opcode, cc] = GetParam();
bool dtype_is_complex = data_type == C64 || data_type == C128;
const std::string kHloTestTemplate =
absl::Substitute(R"(
custom_call {
Expand All @@ -716,7 +721,7 @@ ENTRY triton_computation {
ROOT reduce = $0[125] reduce(parameter_0, constant_0),
dimensions={1}, to_apply=custom_call
})",
"$0", dtype_is_complex ? "(0, 0)" : "0");
"$0", init_value(data_type));
TF_ASSERT_OK_AND_ASSIGN(
TestedInstruction ti,
ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
Expand All @@ -740,7 +745,6 @@ using ReductionComputationTest =
// computation and in regular HLO. See triton_support.cc for more details.
TEST_P(ReductionComputationTest, DifferentBinaryOps) {
auto [data_type, opcode, cc] = GetParam();
bool dtype_is_complex = data_type == C64 || data_type == C128;
const std::string kHloTestTemplate = absl::Substitute(
R"(
reduce_computation {
Expand All @@ -755,7 +759,7 @@ ENTRY triton_computation {
ROOT reduce = $0[125] reduce(parameter_0, constant_0),
dimensions={1}, to_apply=reduce_computation
})",
"$0", HloOpcodeString(opcode), dtype_is_complex ? "(0, 0)" : "0");
"$0", HloOpcodeString(opcode), init_value(data_type));

TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti,
ParseTemplateAndGetInstruction(
Expand Down Expand Up @@ -1115,13 +1119,12 @@ TEST_P(ConstantTest, ConstantEffectiveScalar) {
// The IsTritonSupportedReduction effectively tests the scalar constant
// support.
auto [data_type, cc] = GetParam();
bool dtype_is_complex = data_type == C64 || data_type == C128;
const std::string kHloTestTemplate =
absl::Substitute(R"(
ENTRY triton_computation {
ROOT const = $0[1,1] constant({{$1}})
})",
"$0", dtype_is_complex ? "(0, 0)" : "0");
"$0", init_value(data_type));

TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction(
kHloTestTemplate, data_type,
Expand All @@ -1133,13 +1136,12 @@ TEST_P(ConstantTest, Constant2D) {
// The IsTritonSupportedReduction effectively tests the scalar constant
// support.
auto [data_type, cc] = GetParam();
bool dtype_is_complex = data_type == C64 || data_type == C128;
const std::string kHloTestTemplate =
absl::Substitute(R"(
ENTRY triton_computation {
ROOT const = $0[3,3] constant({{$1,$1,$1},{$1,$1,$1},{$1,$1,$1}})
})",
"$0", dtype_is_complex ? "(0, 0)" : "0");
"$0", init_value(data_type));

TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction(
kHloTestTemplate, data_type,
Expand Down
18 changes: 9 additions & 9 deletions xla/tests/convert_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2146,15 +2146,15 @@ XLA_TYPED_TEST(ConvertTestT, ConvertF8e8m0fnuRoundtripExhaustive) {
}

XLA_TYPED_TEST(ConvertTestT, ConvertF8e8m0fnuRoundtripExhaustive2) {
#ifdef XLA_TEST_BACKEND_CPU
// This test is disabled on CPU, as converting 0x1p-127 from double to float
// using CVTSD2SS on x64 results in an underflow (even though the result is
// representable as denormalized float32).
if (std::is_same_v<TypeParam, double>) {
GTEST_SKIP() << "Skipping test for double precision floating point that "
"loses denormal value during conversion";
}
#endif
if (this->client_->platform()->Name() == "Host") {
// This test is disabled on CPU, as converting 0x1p-127 from double to float
// using CVTSD2SS on x64 results in an underflow (even though the result is
// representable as denormalized float32).
if (std::is_same_v<TypeParam, double>) {
GTEST_SKIP() << "Skipping test for double precision floating point that "
"loses denormal value during conversion";
}
}
// Convert from supported floating point type to FP8.
XlaBuilder builder(this->TestName());

Expand Down

0 comments on commit d4de0a3

Please sign in to comment.