-
Notifications
You must be signed in to change notification settings - Fork 456
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Float8E4M3FNUZ -> Float8E4M3FN for NVIDIA PTX
Imported from openxla/triton#8. PiperOrigin-RevId: 665336874
- Loading branch information
1 parent
abfb042
commit 4de78fc
Showing
4 changed files
with
278 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,272 @@ | ||
This patch can be removed as part of the next integrate. | ||
The corresponding import patch has already been added. | ||
|
||
==== triton/include/triton/Dialect/Triton/IR/TritonTypes.td#13 - triton/include/triton/Dialect/Triton/IR/TritonTypes.td ==== | ||
# action=edit type=text | ||
--- triton/include/triton/Dialect/Triton/IR/TritonTypes.td 2024-06-07 05:28:31.000000000 -0700 | ||
+++ triton/include/triton/Dialect/Triton/IR/TritonTypes.td 2024-08-20 06:34:55.000000000 -0700 | ||
@@ -15,7 +15,7 @@ | ||
} | ||
|
||
// Floating-point Type | ||
-def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; | ||
+def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; | ||
def TT_FloatTensor : RankedTensorOf<[TT_Float]>; | ||
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; | ||
|
||
==== triton/lib/Analysis/Utility.cpp#42 - triton/lib/Analysis/Utility.cpp ==== | ||
# action=edit type=text | ||
--- triton/lib/Analysis/Utility.cpp 2024-08-14 09:36:23.000000000 -0700 | ||
+++ triton/lib/Analysis/Utility.cpp 2024-08-20 06:34:55.000000000 -0700 | ||
@@ -425,6 +425,7 @@ | ||
if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth()) | ||
return false; | ||
|
||
+ auto F8E4M3FN = TypeID::get<Float8E4M3FNType>(); | ||
auto F8E4M3FNUZ = TypeID::get<Float8E4M3FNUZType>(); | ||
auto F8E5M2FNUZ = TypeID::get<Float8E5M2FNUZType>(); | ||
auto F16 = TypeID::get<Float16Type>(); | ||
@@ -435,6 +436,7 @@ | ||
{F32, F32}, | ||
{F16, F16}, | ||
{BF16, BF16}, | ||
+ {F8E4M3FN, F8E4M3FN}, | ||
{F8E4M3FNUZ, F8E4M3FNUZ}, | ||
{F8E4M3FNUZ, F8E5M2FNUZ}, | ||
{F8E5M2FNUZ, F8E4M3FNUZ}, | ||
@@ -493,14 +495,14 @@ | ||
return false; | ||
if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && | ||
retShapePerCTA[rank - 1] % 8 == 0 && | ||
- (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() || | ||
+ (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() || | ||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || | ||
aElemTy.isF32()))) { | ||
return false; | ||
} | ||
// We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. | ||
if (op.getMaxNumImpreciseAcc() < 32 && | ||
- (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) && | ||
+ (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) && | ||
cast<RankedTensorType>(op.getType()).getElementType().isF32()) { | ||
return false; | ||
} | ||
==== triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp#20 - triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp ==== | ||
# action=edit type=text | ||
--- triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp 2024-06-07 05:28:31.000000000 -0700 | ||
+++ triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp 2024-08-20 06:34:55.000000000 -0700 | ||
@@ -34,6 +34,9 @@ | ||
addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> { | ||
return IntegerType::get(type.getContext(), 8); | ||
}); | ||
+ addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> { | ||
+ return IntegerType::get(type.getContext(), 8); | ||
+ }); | ||
addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> { | ||
return IntegerType::get(type.getContext(), 8); | ||
}); | ||
==== triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp#44 - triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp ==== | ||
# action=edit type=text | ||
--- triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp 2024-07-31 01:05:00.000000000 -0700 | ||
+++ triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp 2024-08-20 06:40:32.000000000 -0700 | ||
@@ -382,7 +382,7 @@ | ||
NvidiaMmaEncodingAttr mmaLayout = | ||
dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding()); | ||
if (mmaLayout) { | ||
- bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ(); | ||
+ bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN(); | ||
// promote operands for sm < 89 since fp8 mma is not natively supported | ||
// promote operands for sm >= 90 when mma is not v3 | ||
if (!isNativeFP8 || | ||
==== triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp#39 - triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp ==== | ||
# action=edit type=text | ||
--- triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2024-08-14 09:36:23.000000000 -0700 | ||
+++ triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2024-08-20 06:34:55.000000000 -0700 | ||
@@ -45,8 +45,9 @@ | ||
SmallVector<unsigned> validN; | ||
|
||
// MMAv3 with larger instruction shape is preferred. | ||
- if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() || | ||
- eltType.isF16() || eltType.isBF16() || eltType.isF32()) { | ||
+ if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() || | ||
+ eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() || | ||
+ eltType.isF32()) { | ||
validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, | ||
168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, | ||
80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); | ||
==== triton/patches/public/fp8_fix.patch#None - triton/patches/public/fp8_fix.patch ==== | ||
# action=add type=text | ||
--- /dev/null 1969-12-31 16:00:00.000000000 -0800 | ||
+++ triton/patches/public/fp8_fix.patch 2024-08-21 01:51:13.000000000 -0700 | ||
@@ -0,0 +1,2 @@ | ||
+triton/patches/public/fp8_fix.patch#1 - opened for add | ||
+triton/patches/public/fp8_fix.patch - empty, assuming text. | ||
==== triton/python/src/ir.cc#24 - triton/python/src/ir.cc ==== | ||
# action=edit type=text | ||
--- triton/python/src/ir.cc 2024-08-12 00:24:31.000000000 -0700 | ||
+++ triton/python/src/ir.cc 2024-08-21 01:46:02.000000000 -0700 | ||
@@ -745,10 +745,8 @@ | ||
return self.getBuilder().getI64Type(); | ||
}) | ||
.def("get_fp8e4nv_ty", | ||
- // TODO: fp8e4nv is using Float8E4M3FNUZType, which | ||
- // does not seem right. It should use FloatE4M3FNType | ||
[](TritonOpBuilder &self) -> Type { | ||
- return self.getBuilder().getType<Float8E4M3FNUZType>(); | ||
+ return self.getBuilder().getType<Float8E4M3FNType>(); | ||
}) | ||
.def("get_fp8e4b8_ty", | ||
[](TritonOpBuilder &self) -> Type { | ||
==== triton/test/Conversion/tritongpu_to_llvm_hopper.mlir#25 - triton/test/Conversion/tritongpu_to_llvm_hopper.mlir ==== | ||
# action=edit type=text | ||
--- triton/test/Conversion/tritongpu_to_llvm_hopper.mlir 2024-07-03 07:14:55.000000000 -0700 | ||
+++ triton/test/Conversion/tritongpu_to_llvm_hopper.mlir 2024-08-20 06:34:55.000000000 -0700 | ||
@@ -129,24 +129,24 @@ | ||
module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { | ||
// CHECK-LABEL: test_fp8_to_f16_conversion | ||
tt.func @test_fp8_to_f16_conversion( | ||
- %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FNUZ, #blocked>, | ||
+ %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FN, #blocked>, | ||
%in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) { | ||
// CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> | ||
%out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked> | ||
// CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> | ||
- %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked> | ||
+ %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FN, #blocked> -> tensor<128xf16, #blocked> | ||
// CHECK-COUNT-2: mul.rn.bf16x2 | ||
%out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked> | ||
|
||
// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> | ||
%out3 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked> | ||
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> | ||
- %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> | ||
+ %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FN, #blocked> | ||
|
||
// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> | ||
%out5 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked> | ||
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> | ||
- %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> | ||
+ %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FN, #blocked> | ||
tt.return | ||
} | ||
} | ||
==== triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp#4 - triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp ==== | ||
# action=edit type=text | ||
--- triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp 2024-05-14 06:33:36.000000000 -0700 | ||
+++ triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp 2024-08-20 06:34:55.000000000 -0700 | ||
@@ -81,9 +81,9 @@ | ||
FP32_TF32_TF32_FP32, | ||
FP16_FP16_FP16_FP16, | ||
FP32_FP8E5M2_FP8E5M2_FP32, | ||
- FP32_FP8E5M2_FP8E4M3FNUZ_FP32, | ||
- FP32_FP8E4M3FNUZ_FP8E5M2_FP32, | ||
- FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32, | ||
+ FP32_FP8E5M2_FP8E4M3FN_FP32, | ||
+ FP32_FP8E4M3FN_FP8E5M2_FP32, | ||
+ FP32_FP8E4M3FN_FP8E4M3FN_FP32, | ||
// integer tensor core instr | ||
INT32_INT1_INT1_INT32, // Not implemented | ||
INT32_INT4_INT4_INT32, // Not implemented | ||
@@ -112,9 +112,9 @@ | ||
case TensorCoreType::FP16_FP16_FP16_FP16: | ||
return fp16x2Pack2Ty; | ||
case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32: | ||
- case TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32: | ||
- case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32: | ||
- case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32: | ||
+ case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32: | ||
+ case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32: | ||
+ case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32: | ||
return fp32x4Ty; | ||
case TensorCoreType::INT32_INT8_INT8_INT32: | ||
return i32x4Ty; | ||
@@ -140,14 +140,14 @@ | ||
bTy.getElementType().isFloat8E5M2()) | ||
return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32; | ||
if (aTy.getElementType().isFloat8E5M2() && | ||
- bTy.getElementType().isFloat8E4M3FNUZ()) | ||
- return TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32; | ||
- if (aTy.getElementType().isFloat8E4M3FNUZ() && | ||
+ bTy.getElementType().isFloat8E4M3FN()) | ||
+ return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32; | ||
+ if (aTy.getElementType().isFloat8E4M3FN() && | ||
bTy.getElementType().isFloat8E5M2()) | ||
- return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32; | ||
- if (aTy.getElementType().isFloat8E4M3FNUZ() && | ||
- bTy.getElementType().isFloat8E4M3FNUZ()) | ||
- return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32; | ||
+ return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32; | ||
+ if (aTy.getElementType().isFloat8E4M3FN() && | ||
+ bTy.getElementType().isFloat8E4M3FN()) | ||
+ return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32; | ||
if (aTy.getElementType().isF32() && bTy.getElementType().isF32() && | ||
op.getInputPrecision() == InputPrecision::TF32) | ||
return TensorCoreType::FP32_TF32_TF32_FP32; | ||
@@ -193,11 +193,11 @@ | ||
|
||
{TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32, | ||
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32"}, | ||
- {TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32, | ||
+ {TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32, | ||
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32"}, | ||
- {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32, | ||
+ {TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32, | ||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32"}, | ||
- {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32, | ||
+ {TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32, | ||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32"}, | ||
}; | ||
|
||
==== triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp#9 - triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp ==== | ||
# action=edit type=text | ||
--- triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2024-06-07 05:28:31.000000000 -0700 | ||
+++ triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2024-08-20 06:34:55.000000000 -0700 | ||
@@ -58,7 +58,7 @@ | ||
return triton::nvgpu::WGMMAEltType::s8; | ||
} else if (aTy.isFloat8E5M2()) { | ||
return triton::nvgpu::WGMMAEltType::e5m2; | ||
- } else if (aTy.isFloat8E4M3FNUZ()) { | ||
+ } else if (aTy.isFloat8E4M3FN()) { | ||
return triton::nvgpu::WGMMAEltType::e4m3; | ||
} else { | ||
llvm::report_fatal_error("Unsupported mma operand type found"); | ||
==== triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp#9 - triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp ==== | ||
# action=edit type=text | ||
--- triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp 2024-07-17 02:05:59.000000000 -0700 | ||
+++ triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp 2024-08-20 06:34:55.000000000 -0700 | ||
@@ -386,7 +386,7 @@ | ||
std::pair<ConverterT, size_t> | ||
getConversionFunc(Type srcTy, Type dstTy, | ||
std::optional<RoundingMode> roundingMode) const { | ||
- auto F8E4M3TyID = TypeID::get<Float8E4M3FNUZType>(); | ||
+ auto F8E4M3TyID = TypeID::get<Float8E4M3FNType>(); | ||
auto F8E5M2TyID = TypeID::get<Float8E5M2Type>(); | ||
auto F16TyID = TypeID::get<Float16Type>(); | ||
auto BF16TyID = TypeID::get<BFloat16Type>(); | ||
@@ -430,7 +430,7 @@ | ||
llvm::report_fatal_error("Unsupported rounding mode for conversion."); | ||
} | ||
if (computeCapability < 89 && | ||
- (srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) { | ||
+ (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) { | ||
llvm::errs() << "Conversion from/to f8e4m3nv is only supported on " | ||
"compute capability >= 89" | ||
<< "\n"; | ||
@@ -452,7 +452,7 @@ | ||
auto dstElementType = getElementType(op.getResult()); | ||
auto roundingMode = op.getRounding(); | ||
|
||
- if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FNUZ()) { | ||
+ if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) { | ||
assert(roundingMode.has_value() && | ||
"Rounding mode must be specified for convertsions to fp8"); | ||
|
||
@@ -489,7 +489,7 @@ | ||
|
||
bool useFP16IntermediateSrc = | ||
srcElementType.isF32() && | ||
- (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FNUZ() || | ||
+ (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() || | ||
dstElementType.isFloat8E5M2())) || | ||
roundingMode.value() == RoundingMode::RTZ); | ||
bool isDstFP32 = dstElementType.isF32(); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters