Float8E4M3FNUZ -> Float8E4M3FN for NVIDIA PTX #8

Closed
wants to merge 10 commits
888 changes: 888 additions & 0 deletions BUILD

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
@@ -1 +1 @@
4713bd4ccc0c0d568f92916e7851d993291742c0
4c5ef6690040383956461828457ac27f7f912edb
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonTypes.td
@@ -15,7 +15,7 @@ class TritonTypeDef<string name, string _mnemonic, list<Trait> traits = []>
}

// Floating-point Type
def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
Member commented:

Does F8E4M3FNUZ really need to be removed?

Also, below, in some places F8E4M3FN is added while in other places F8E4M3FNUZ is replaced. It would be good to explain what we want in the PR description and apply it consistently. Or maybe I'm missing something and this is all intentional?

Author replied:

F8E4M3FNUZ needs to stay listed here, as we will likely want support for it on other platforms (although I haven't tested this).

In places that target NVIDIA PTX, we replace F8E4M3FNUZ with F8E4M3FN; in other "generic" places we add F8E4M3FN.

def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

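To make the policy from the review thread above concrete, here is a minimal sketch (not part of the PR; the helper and the target flag are hypothetical) of how a check might distinguish the two e4m3 flavors, using the same MLIR Type predicates the diffs rely on:

#include "mlir/IR/BuiltinTypes.h"

// Hypothetical helper for illustration only: NVIDIA PTX paths accept the
// F8E4M3FN flavor, while target-independent paths keep accepting F8E4M3FNUZ
// as well (e.g. for the AMD MFMA checks below).
static bool isSupportedE4M3(mlir::Type ty, bool targetIsNvidiaPTX) {
  if (targetIsNvidiaPTX)
    return ty.isFloat8E4M3FN();
  return ty.isFloat8E4M3FN() || ty.isFloat8E4M3FNUZ();
}

In the diffs that follow, this shows up as F8E4M3FN replacing F8E4M3FNUZ in the NVIDIA MMA checks, while both types stay listed in generic places such as TT_Float and supportMFMATypes.
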
6 changes: 4 additions & 2 deletions lib/Analysis/Utility.cpp
@@ -425,6 +425,7 @@ bool supportMFMATypes(Type a, Type b) {
if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth())
return false;

auto F8E4M3FN = TypeID::get<Float8E4M3FNType>();
auto F8E4M3FNUZ = TypeID::get<Float8E4M3FNUZType>();
auto F8E5M2FNUZ = TypeID::get<Float8E5M2FNUZType>();
auto F16 = TypeID::get<Float16Type>();
@@ -435,6 +436,7 @@ bool supportMFMATypes(Type a, Type b) {
{F32, F32},
{F16, F16},
{BF16, BF16},
{F8E4M3FN, F8E4M3FN},
{F8E4M3FNUZ, F8E4M3FNUZ},
{F8E4M3FNUZ, F8E5M2FNUZ},
{F8E5M2FNUZ, F8E4M3FNUZ},
@@ -493,14 +495,14 @@ bool supportMMA(triton::DotOp op, int version) {
return false;
if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
retShapePerCTA[rank - 1] % 8 == 0 &&
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() ||
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
aElemTy.isF32()))) {
return false;
}
// We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
if (op.getMaxNumImpreciseAcc() < 32 &&
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) &&
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
return false;
}
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -40,7 +40,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
auto ouEltTy = ouTensorTy.getElementType();
if (inBitWidth == ouBitWidth)
return values;
if (inBitWidth == 16 && ouBitWidth == 32) {
if ((inBitWidth == 16 && ouBitWidth == 32) ||
(inBitWidth == 32 && ouBitWidth == 16)) {
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 8) {
ret.push_back(values[i]);
3 changes: 3 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -34,6 +34,9 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -87,8 +87,9 @@ struct ArithConstantSplatOpConversion
// LLVM IR.
if (type::isFloat8(elemType))
elemType = rewriter.getIntegerType(8);
auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
auto typeConverter = getTypeConverter();
auto constOp = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(elemType), val);
auto llStruct = SplatOpConversion::convertSplatLikeOp(
elemType, op.getType(), constOp, typeConverter, rewriter, loc);
rewriter.replaceOp(op, llStruct);
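
The two hunks above (TypeConverter.cpp and ViewOpToLLVM.cpp) share one idea: 8-bit float values are carried as opaque i8 at the LLVM level, so the new Float8E4M3FN type needs its own conversion, and constants must be built with the converted type. A self-contained sketch of that registration, assuming a plain mlir::TypeConverter rather than Triton's subclass:

#include <optional>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

// Map every 8-bit float flavor to i8, mirroring the addConversion lambdas
// added in TypeConverter.cpp above.
static void addFp8ToI8Conversions(mlir::TypeConverter &converter) {
  converter.addConversion([](mlir::Float8E4M3FNType type) -> std::optional<mlir::Type> {
    return mlir::IntegerType::get(type.getContext(), 8);
  });
  converter.addConversion([](mlir::Float8E4M3FNUZType type) -> std::optional<mlir::Type> {
    return mlir::IntegerType::get(type.getContext(), 8);
  });
  converter.addConversion([](mlir::Float8E5M2Type type) -> std::optional<mlir::Type> {
    return mlir::IntegerType::get(type.getContext(), 8);
  });
}

With that registration in place, the ViewOpToLLVM change simply asks the converter for the element type before building the LLVM::ConstantOp, so an fp8 splat constant is materialized as an i8 constant.
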
5 changes: 5 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2721,6 +2721,11 @@ struct CanonicalizeConvertFromAlloc
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
// to SharedEncoding, so we want to keep this layout conversion.
if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
convert.getSrc().getType().getEncoding()))
return failure();
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
op, op->getResult(0).getType(), convert.getSrc());
return mlir::success();
26 changes: 25 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
newLayout, SharedMemorySpace);
rewriter.setInsertionPointAfterValue(arg);

// LocalAllocOp lowering doesn't support going from DotOperandEncoding
// to SharedEncoding.
if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
argType.getEncoding())) {
// Create a layout conversion from DotOperandEncoding to BlockedEncoding
// then pass it to the LocalAllocOp.
auto newArgType = RankedTensorType::get(
argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
auto dotOperandToBlockedCvt =
rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
dotOperandToBlockedCvt);
}

return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
}

@@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;

static bool bwdFilter(Operation *op) {
// Dot operand layout assignment to predicates is not currently supported
// during lowering from TritonGPU to LLVM in Triton for MMA cases. This
// condition limits visibility of the original bit-width so that predicates
// are not considered and kWidth can never be 32.
if (isa<arith::UIToFPOp>(op)) {
Type srcType = getElementTypeOrSelf(op->getOperand(0));
if (srcType.isInteger(1))
return false;
}
return op->getNumOperands() == 1 &&
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
isPureUnaryInlineAsm(op) ||
@@ -357,7 +381,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
NvidiaMmaEncodingAttr mmaLayout =
dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
if (mmaLayout) {
bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ();
bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
// promote operands for sm < 89 since fp8 mma is not natively supported
// promote operands for sm >= 90 when mma is not v3
if (!isNativeFP8 ||
17 changes: 16 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
PatternRewriter &rewriter) const override {
// Only consider conversions to dot operand.
auto cvtTy = cast<RankedTensorType>(cvt.getType());
if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
if (!dotOpEnc)
return failure();

auto src = cvt.getSrc().getDefiningOp();
@@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
[](Type ty) { return isa<RankedTensorType>(ty); }))
return failure();

// Quick fix for loading issues: the original-bitwidth computation can fail
// to realize that there is a mixed-precision dot (hence kWidth = 1) and
// still try to hoist through the type conversion.
if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
return failure();

// Only consider custom conversions or arith ops.
// TODO(jlebar): Is this too restrictive?
if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
@@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
return failure();

// Don't hoist through u1 -> fp casts as they aren't supported in
// ElementwiseOpToLLVM::reorderValues().
if (isa<arith::UIToFPOp>(src)) {
Type srcType = getElementTypeOrSelf(src->getOperand(0));
if (srcType.isInteger(1))
return failure();
}

// Check that the conversion is transitively dependent on a load, and all
// operations between the load and the conversion are layout preserving.
//
17 changes: 16 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
type.getMemorySpace()),
v, offsetsVal);

// We need to set kwidth to zero when the parent layout is Blocked;
// otherwise the verifier emits a failure. The parent layout is Blocked
// only when Tensor Cores are disabled.
int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
? 0
: prefetchWidth / 8;
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
builder.getContext(), opIdx, dotEncoding, kwidth);
Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
@@ -187,6 +193,15 @@ LogicalResult Prefetcher::initialize() {
break;
if (!op->getResult(0).hasOneUse())
break;
// Similar to the issue handled by the HoistLayoutConversion pattern in
// OptimizeDotOperands.cpp, we can't propagate through type casts from
// predicates, as they aren't supported in Triton when encoded with a
// dot_op layout.
if (isa<arith::UIToFPOp>(op)) {
Type srcType = getElementTypeOrSelf(op->getOperand(0));
if (srcType.isInteger(1))
break;
}
rets.push_back(op->getOperand(0));
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
foundConvertFromShared = true;
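
The same guard appears in three places above (BlockedToMMA::bwdFilter, HoistLayoutConversion, and Prefetcher::initialize): don't look through an integer-to-float cast whose source is an i1 predicate, because predicates encoded with a dot_op layout aren't supported during lowering. A small sketch of that shared check (the helper name is made up; the PR inlines the check at each call site):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"

// True if `op` is a uitofp whose source element type is i1 (a predicate).
// The backward-slice, hoisting, and prefetch logic above bails out on these.
static bool isPredicateToFloatCast(mlir::Operation *op) {
  if (!mlir::isa<mlir::arith::UIToFPOp>(op))
    return false;
  mlir::Type srcTy = mlir::getElementTypeOrSelf(op->getOperand(0));
  return srcTy.isInteger(1);
}
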
46 changes: 1 addition & 45 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -138,11 +138,6 @@ class LayoutRematerialization {
ConvertLayoutOp convertOp);

private:
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
// Existing tuples of (value, layout) that need to be updated when recreating
// scf ops. This prevents keeping track of Values that have been deleted when
// rewriting slices.
DenseMap<Value, Attribute> mappedValues;
// map of the values remat based on encoding.
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -154,7 +149,6 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
Value newV) {
LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
rematMapping[{old, encoding}] = newV;
mappedValues[old] = encoding;
}

// Remove unneeded values now that we are done with the rematMapping.
@@ -807,31 +801,6 @@ bool canBeRemat(Operation *op) {
return true;
}

void LayoutRematerialization::updateRematMapping(
SmallVector<std::tuple<Value, Value>> &values) {
for (auto [old, newV] : values) {
auto it = mappedValues.find(old);
if (it != mappedValues.end()) {
Attribute encoding = it->second;
auto rematIt = rematMapping.find({old, it->second});
assert(rematIt != rematMapping.end());
Value replacedValue = rematIt->second;
rematMapping.erase(rematIt);
mappedValues.erase(it);
// Loop through the replacement value to find the new version of remat
// value. This should be okay as the number of values should be small.
for (auto [before, after] : values) {
if (before == replacedValue) {
replacedValue = after;
break;
}
}
rematMapping[{newV, encoding}] = replacedValue;
mappedValues[newV] = encoding;
}
}
}

void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp,
@@ -844,14 +813,6 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
// for/yield to fall out of sync
SetVector<Value> valuesWithExistingRemat;
for (Value v : slice) {
auto layoutIt = layout.find(v);
assert(layoutIt != layout.end());
// If we already have a remat value for this value, use it.
if (hasRematValue(v, layoutIt->second)) {
mapping.map(v, getRematValue(v, layoutIt->second));
valuesWithExistingRemat.insert(v);
continue;
}
if (v.getDefiningOp()) {
opsToRewrite.insert(v.getDefiningOp());
if (auto ifOp = v.getDefiningOp<scf::IfOp>()) {
@@ -941,8 +902,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
if (slice.count(res)) {
// Why can't we use res instead of ifOp.getResult(oldIdx)?
mapping.map(ifOp.getResult(oldIdx), newIfOp.getResult(newIdx));
addRematValue(ifOp.getResult(oldIdx), layout[res],
newIfOp.getResult(newIdx));
addRematValue(res, layout[res], newIfOp.getResult(newIdx));
++newIdx;
}
++oldIdx;
@@ -973,8 +933,6 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
auto cvt = builder.create<ConvertLayoutOp>(op->getLoc(), newType,
newOp->getResult(0));
mapping.map(op->getResult(0), cvt.getResult());
addRematValue(op->getResult(0), layout[op->getResult(0)],
cvt.getResult());
continue;
}
Operation *newOp = builder.clone(*op, mapping);
@@ -986,14 +944,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
cast<RankedTensorType>(old.getType()).getShape(),
cast<RankedTensorType>(old.getType()).getElementType(), it->second);
newV.setType(newType);
addRematValue(old, it->second, newV);
}
}
// Check mapping and see if there are existing convertOps on the old Argument
convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc()));
opToDelete.insert(convertOp);

updateRematMapping(replacements);
for (auto &kv : replacements) {
builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
}
5 changes: 3 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -45,8 +45,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
SmallVector<unsigned> validN;

// MMAv3 with larger instruction shape is preferred.
if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() ||
eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
eltType.isF32()) {
validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88,
80, 72, 64, 56, 48, 40, 32, 24, 16, 8});
76 changes: 76 additions & 0 deletions python/BUILD
@@ -0,0 +1,76 @@
# NOTE: Do not depend on any targets from this directory,
# but use //third_party/py/triton instead.

load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(
default_applicable_licenses = ["//:license"],
default_visibility = [
"//third_party/py/triton:__pkg__",
"@triton//python:__subpackages__",
],
)

cc_library(
name = "passes",
hdrs = ["src/passes.h"],
includes = ["src"],
visibility = ["@triton//third_party:__subpackages__"],
)

pybind_extension(
name = "libtriton",
srcs = [
"src/interpreter.cc",
"src/ir.cc",
"src/llvm.cc",
"src/main.cc",
"src/passes.cc",
],
copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"],
deps = [
":passes",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:IPO",
"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:InstCombine",
"@llvm-project//llvm:Linker",
"@llvm-project//llvm:MC",
"@llvm-project//llvm:Passes",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
"@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:ConversionPasses",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:IndexDialect",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:ToLLVMIRTranslation",
"@llvm-project//mlir:Transforms",
"//:TritonAnalysis",
"//:TritonDialects",
"//:TritonGPUToLLVM",
"//:TritonGPUTransforms",
"//:TritonHSACO",
"//:TritonLLVMIR",
"//:TritonNvidiaGPUTransforms",
"//:TritonPTX",
"//:TritonToTritonGPU",
"//:TritonTools",
"//:TritonTransforms",
"@triton//third_party/nvidia:triton_nvidia",
],
)

filegroup(
name = "files",
srcs = glob(
include = ["triton/**/*.py"],
),
)