Float8E4M3FNUZ -> Float8E4M3FN for NVIDIA PTX
Imported from openxla/triton#8

PiperOrigin-RevId: 665336874
chsigg authored and copybara-github committed Aug 26, 2024
1 parent abfb042 commit 533e281
Showing 4 changed files with 278 additions and 8 deletions.
272 changes: 272 additions & 0 deletions third_party/triton/temporary/fp8_fix.patch
@@ -0,0 +1,272 @@
This patch can be removed as part of the next integrate.
The corresponding import patch has already been added.

==== triton/include/triton/Dialect/Triton/IR/TritonTypes.td#13 - triton/include/triton/Dialect/Triton/IR/TritonTypes.td ====
# action=edit type=text
--- triton/include/triton/Dialect/Triton/IR/TritonTypes.td 2024-06-07 05:28:31.000000000 -0700
+++ triton/include/triton/Dialect/Triton/IR/TritonTypes.td 2024-08-20 06:34:55.000000000 -0700
@@ -15,7 +15,7 @@
}

// Floating-point Type
-def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
+def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

==== triton/lib/Analysis/Utility.cpp#42 - triton/lib/Analysis/Utility.cpp ====
# action=edit type=text
--- triton/lib/Analysis/Utility.cpp 2024-08-14 09:36:23.000000000 -0700
+++ triton/lib/Analysis/Utility.cpp 2024-08-20 06:34:55.000000000 -0700
@@ -425,6 +425,7 @@
if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth())
return false;

+ auto F8E4M3FN = TypeID::get<Float8E4M3FNType>();
auto F8E4M3FNUZ = TypeID::get<Float8E4M3FNUZType>();
auto F8E5M2FNUZ = TypeID::get<Float8E5M2FNUZType>();
auto F16 = TypeID::get<Float16Type>();
@@ -435,6 +436,7 @@
{F32, F32},
{F16, F16},
{BF16, BF16},
+ {F8E4M3FN, F8E4M3FN},
{F8E4M3FNUZ, F8E4M3FNUZ},
{F8E4M3FNUZ, F8E5M2FNUZ},
{F8E5M2FNUZ, F8E4M3FNUZ},
@@ -493,14 +495,14 @@
return false;
if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
retShapePerCTA[rank - 1] % 8 == 0 &&
- (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() ||
+ (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
aElemTy.isF32()))) {
return false;
}
// We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
if (op.getMaxNumImpreciseAcc() < 32 &&
- (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) &&
+ (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
return false;
}
==== triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp#20 - triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp ====
# action=edit type=text
--- triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp 2024-06-07 05:28:31.000000000 -0700
+++ triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp 2024-08-20 06:34:55.000000000 -0700
@@ -34,6 +34,9 @@
addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
+ addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
+ return IntegerType::get(type.getContext(), 8);
+ });
addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
==== triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp#44 - triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp ====
# action=edit type=text
--- triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp 2024-07-31 01:05:00.000000000 -0700
+++ triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp 2024-08-20 06:40:32.000000000 -0700
@@ -382,7 +382,7 @@
NvidiaMmaEncodingAttr mmaLayout =
dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
if (mmaLayout) {
- bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ();
+ bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
// promote operands for sm < 89 since fp8 mma is not natively supported
// promote operands for sm >= 90 when mma is not v3
if (!isNativeFP8 ||
==== triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp#39 - triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp ====
# action=edit type=text
--- triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2024-08-14 09:36:23.000000000 -0700
+++ triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2024-08-20 06:34:55.000000000 -0700
@@ -45,8 +45,9 @@
SmallVector<unsigned> validN;

// MMAv3 with larger instruction shape is preferred.
- if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() ||
- eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
+ if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
+ eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
+ eltType.isF32()) {
validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88,
80, 72, 64, 56, 48, 40, 32, 24, 16, 8});
==== triton/patches/public/fp8_fix.patch#None - triton/patches/public/fp8_fix.patch ====
# action=add type=text
--- /dev/null 1969-12-31 16:00:00.000000000 -0800
+++ triton/patches/public/fp8_fix.patch 2024-08-21 01:51:13.000000000 -0700
@@ -0,0 +1,2 @@
+triton/patches/public/fp8_fix.patch#1 - opened for add
+triton/patches/public/fp8_fix.patch - empty, assuming text.
==== triton/python/src/ir.cc#24 - triton/python/src/ir.cc ====
# action=edit type=text
--- triton/python/src/ir.cc 2024-08-12 00:24:31.000000000 -0700
+++ triton/python/src/ir.cc 2024-08-21 01:46:02.000000000 -0700
@@ -745,10 +745,8 @@
return self.getBuilder().getI64Type();
})
.def("get_fp8e4nv_ty",
- // TODO: fp8e4nv is using Float8E4M3FNUZType, which
- // does not seem right. It should use FloatE4M3FNType
[](TritonOpBuilder &self) -> Type {
- return self.getBuilder().getType<Float8E4M3FNUZType>();
+ return self.getBuilder().getType<Float8E4M3FNType>();
})
.def("get_fp8e4b8_ty",
[](TritonOpBuilder &self) -> Type {
==== triton/test/Conversion/tritongpu_to_llvm_hopper.mlir#25 - triton/test/Conversion/tritongpu_to_llvm_hopper.mlir ====
# action=edit type=text
--- triton/test/Conversion/tritongpu_to_llvm_hopper.mlir 2024-07-03 07:14:55.000000000 -0700
+++ triton/test/Conversion/tritongpu_to_llvm_hopper.mlir 2024-08-20 06:34:55.000000000 -0700
@@ -129,24 +129,24 @@
module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_fp8_to_f16_conversion
tt.func @test_fp8_to_f16_conversion(
- %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FNUZ, #blocked>,
+ %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FN, #blocked>,
%in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) {
// CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
%out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked>
// CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
- %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked>
+ %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FN, #blocked> -> tensor<128xf16, #blocked>
// CHECK-COUNT-2: mul.rn.bf16x2
%out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked>

// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
%out3 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
- %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
+ %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FN, #blocked>

// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
%out5 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
- %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
+ %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FN, #blocked>
tt.return
}
}
==== triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp#4 - triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp ====
# action=edit type=text
--- triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp 2024-05-14 06:33:36.000000000 -0700
+++ triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp 2024-08-20 06:34:55.000000000 -0700
@@ -81,9 +81,9 @@
FP32_TF32_TF32_FP32,
FP16_FP16_FP16_FP16,
FP32_FP8E5M2_FP8E5M2_FP32,
- FP32_FP8E5M2_FP8E4M3FNUZ_FP32,
- FP32_FP8E4M3FNUZ_FP8E5M2_FP32,
- FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32,
+ FP32_FP8E5M2_FP8E4M3FN_FP32,
+ FP32_FP8E4M3FN_FP8E5M2_FP32,
+ FP32_FP8E4M3FN_FP8E4M3FN_FP32,
// integer tensor core instr
INT32_INT1_INT1_INT32, // Not implemented
INT32_INT4_INT4_INT32, // Not implemented
@@ -112,9 +112,9 @@
case TensorCoreType::FP16_FP16_FP16_FP16:
return fp16x2Pack2Ty;
case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32:
- case TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32:
- case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32:
- case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32:
+ case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32:
+ case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32:
+ case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32:
return fp32x4Ty;
case TensorCoreType::INT32_INT8_INT8_INT32:
return i32x4Ty;
@@ -140,14 +140,14 @@
bTy.getElementType().isFloat8E5M2())
return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32;
if (aTy.getElementType().isFloat8E5M2() &&
- bTy.getElementType().isFloat8E4M3FNUZ())
- return TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32;
- if (aTy.getElementType().isFloat8E4M3FNUZ() &&
+ bTy.getElementType().isFloat8E4M3FN())
+ return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32;
+ if (aTy.getElementType().isFloat8E4M3FN() &&
bTy.getElementType().isFloat8E5M2())
- return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32;
- if (aTy.getElementType().isFloat8E4M3FNUZ() &&
- bTy.getElementType().isFloat8E4M3FNUZ())
- return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32;
+ return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32;
+ if (aTy.getElementType().isFloat8E4M3FN() &&
+ bTy.getElementType().isFloat8E4M3FN())
+ return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32;
if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
op.getInputPrecision() == InputPrecision::TF32)
return TensorCoreType::FP32_TF32_TF32_FP32;
@@ -193,11 +193,11 @@

{TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32,
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32"},
- {TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32,
+ {TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32,
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32"},
- {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32,
+ {TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32,
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32"},
- {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32,
+ {TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32,
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32"},
};

==== triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp#9 - triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp ====
# action=edit type=text
--- triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2024-06-07 05:28:31.000000000 -0700
+++ triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2024-08-20 06:34:55.000000000 -0700
@@ -58,7 +58,7 @@
return triton::nvgpu::WGMMAEltType::s8;
} else if (aTy.isFloat8E5M2()) {
return triton::nvgpu::WGMMAEltType::e5m2;
- } else if (aTy.isFloat8E4M3FNUZ()) {
+ } else if (aTy.isFloat8E4M3FN()) {
return triton::nvgpu::WGMMAEltType::e4m3;
} else {
llvm::report_fatal_error("Unsupported mma operand type found");
==== triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp#9 - triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp ====
# action=edit type=text
--- triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp 2024-07-17 02:05:59.000000000 -0700
+++ triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp 2024-08-20 06:34:55.000000000 -0700
@@ -386,7 +386,7 @@
std::pair<ConverterT, size_t>
getConversionFunc(Type srcTy, Type dstTy,
std::optional<RoundingMode> roundingMode) const {
- auto F8E4M3TyID = TypeID::get<Float8E4M3FNUZType>();
+ auto F8E4M3TyID = TypeID::get<Float8E4M3FNType>();
auto F8E5M2TyID = TypeID::get<Float8E5M2Type>();
auto F16TyID = TypeID::get<Float16Type>();
auto BF16TyID = TypeID::get<BFloat16Type>();
@@ -430,7 +430,7 @@
llvm::report_fatal_error("Unsupported rounding mode for conversion.");
}
if (computeCapability < 89 &&
- (srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) {
+ (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) {
llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
"compute capability >= 89"
<< "\n";
@@ -452,7 +452,7 @@
auto dstElementType = getElementType(op.getResult());
auto roundingMode = op.getRounding();

- if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FNUZ()) {
+ if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
assert(roundingMode.has_value() &&
"Rounding mode must be specified for convertsions to fp8");

@@ -489,7 +489,7 @@

bool useFP16IntermediateSrc =
srcElementType.isF32() &&
- (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FNUZ() ||
+ (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() ||
dstElementType.isFloat8E5M2())) ||
roundingMode.value() == RoundingMode::RTZ);
bool isDstFP32 = dstElementType.isF32();
1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
@@ -15,5 +15,6 @@ those to this list.

temporary_patch_list = [
"//third_party/triton:temporary/highestPowOf2Divisor-underflow-fix.patch",
"//third_party/triton:temporary/fp8_fix.patch",
# Add new patches just above this line
]
5 changes: 1 addition & 4 deletions xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
@@ -195,10 +195,7 @@ absl::StatusOr<Type> TritonType(mlir::OpBuilder b, PrimitiveType t) {
case F8E5M2:
return b.getFloat8E5M2Type();
case F8E4M3FN:
-      // TODO(b/345700241) Note that we return UZ type as Triton mistakenly uses
-      // this type for F8E4M3FN. The mapping must be changed when it's fixed in
-      // Triton.
-      return b.getFloat8E4M3FNUZType();
+      return b.getFloat8E4M3FNType();
default:
return absl::UnimplementedError(
absl::StrCat("This type is not supported yet: ",
8 changes: 4 additions & 4 deletions xla/service/gpu/tests/fp8_to_llvm_hopper.mlir
@@ -20,14 +20,14 @@ module attributes {

// CHECK-LABEL: e4m3_mapping
tt.func @e4m3_mapping(
-    %arg0: tensor<16x256xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>,
-    %arg1: tensor<256x16xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+    %arg0: tensor<16x256xf8E4M3FN, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>,
+    %arg1: tensor<256x16xf8E4M3FN, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
// CHECK: mma.{{.*}}.e4m3.e4m3.f32
%res = tt.dot %arg0, %arg1, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32}
-    : tensor<16x256xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> *
-      tensor<256x16xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+    : tensor<16x256xf8E4M3FN, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> *
+      tensor<256x16xf8E4M3FN, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
-> tensor<16x16xf32, #mma>
tt.return
}