[AMD] Refactor mfma selection (triton-lang#3244)
This PR refactors the logic of mfma instruction selection. It brings in
everything from ROCm#441 and parts of ROCm#469, so we should now have full
support for mfma32 and mfma16 with all types. Support for mfma4 is not
complete yet and is left to future PRs, as are tests for AMD f8 inputs.
zhanglx13 authored Mar 1, 2024
1 parent a5c7bd6 commit 5454905
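As a quick illustration of the selection logic described above, here is a minimal C++ sketch of choosing an MFMA tile size from the dot result shape. It is a sketch only: the helper name chooseMfmaDims and the size thresholds are assumptions for illustration, not code from this commit.

// Illustrative sketch only -- not the implementation in this PR. The helper
// name and thresholds are hypothetical; they just mirror the mfma32/mfma16
// preference described in the commit message, with mfma4 left unhandled.
#include <cassert>
#include <cstdint>
#include <utility>

// Pick the (mDim, nDim) of the MFMA intrinsic for an M x N result tile.
inline std::pair<unsigned, unsigned> chooseMfmaDims(int64_t m, int64_t n) {
  if (m >= 32 && n >= 32)
    return {32, 32}; // mfma32: preferred when the tile is large enough
  if (m >= 16 && n >= 16)
    return {16, 16}; // mfma16: fallback for smaller tiles
  assert(false && "mfma4 selection is not complete yet (future PR)");
  return {4, 4};
}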
Showing 21 changed files with 568 additions and 569 deletions.
18 changes: 9 additions & 9 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -404,10 +404,10 @@ using LLVM::getMultiDimIndex;
using LLVM::SharedMemoryObject;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::CTALayoutAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::MfmaEncodingAttr;
using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;

@@ -745,14 +745,14 @@ emitOffsetForMmaLayoutV3(const NvidiaMmaEncodingAttr &mmaLayout,

static SmallVector<Value>
emitBaseIndexForMfmaLayout(Location loc, RewriterBase &rewriter,
const MfmaEncodingAttr &mfmaLayout,
const AMDMfmaEncodingAttr &mfmaLayout,
RankedTensorType type) {
auto shape = type.getShape();
auto _warpsPerCTA = mfmaLayout.getWarpsPerCTA();
assert(_warpsPerCTA.size() == 2);
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
i32_val(_warpsPerCTA[1])};
int nonKDim = mfmaLayout.getNonKDim();
int nonKDim = mfmaLayout.getMDim();

Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout));
@@ -785,10 +785,10 @@ emitBaseIndexForMfmaLayout(Location loc, RewriterBase &rewriter,
return multiDimBase;
}

static void emitMfmaOffsetForCTA(const MfmaEncodingAttr &mfmaLayout,
static void emitMfmaOffsetForCTA(const AMDMfmaEncodingAttr &mfmaLayout,
SmallVector<SmallVector<unsigned>> &offsets,
unsigned ctaOffsetX, unsigned ctaOffsetY) {
auto nonKDim = mfmaLayout.getNonKDim();
auto nonKDim = mfmaLayout.getMDim();
// MFMA output tile consists of repeated "dot operand B" layout groups along
// row axis. This variable defines number of these groups.
const unsigned numGroups = (nonKDim == 32 ? 4 : 1);
@@ -813,7 +813,7 @@ static void emitMfmaOffsetForCTA(const MfmaEncodingAttr &mfmaLayout,
}

static SmallVector<SmallVector<unsigned>>
emitOffsetForMfmaLayout(const MfmaEncodingAttr &mfmaLayout,
emitOffsetForMfmaLayout(const AMDMfmaEncodingAttr &mfmaLayout,
RankedTensorType type) {
auto tensorShape = type.getShape();
SmallVector<SmallVector<unsigned>> offsets;
@@ -824,7 +824,7 @@ emitOffsetForMfmaLayout(const MfmaEncodingAttr &mfmaLayout,
for (unsigned d = 0; d < 2; ++d) {
unsigned inPerCTA = std::min<unsigned>(tensorShape[d], shapePerCTA[d]);
unsigned inPerWarp = ceil<unsigned>(inPerCTA, warpsPerCTA[d]);
numWarpsPerDim[d] = ceil<unsigned>(inPerWarp, mfmaLayout.getNonKDim());
numWarpsPerDim[d] = ceil<unsigned>(inPerWarp, mfmaLayout.getMDim());
}

for (unsigned i = 0; i < numWarpsPerDim[0]; ++i) {
@@ -917,7 +917,7 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, Attribute layout,
if (mmaLayout.isAmpere() || mmaLayout.isHopper())
result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter, mmaLayout,
type);
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
} else if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
@@ -953,7 +953,7 @@ emitOffsetForLayout(Attribute layout, RankedTensorType type) {
if (mmaLayout.isHopper())
return emitOffsetForMmaLayoutV3(mmaLayout, type);
}
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
return emitOffsetForMfmaLayout(mfmaLayout, type);
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>())
51 changes: 42 additions & 9 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -228,7 +228,7 @@ compared to 1*64 when the hasLeadingOffset is false.
"bool":$needTrans), [{

// ---- begin GFX908/GFX90A ----
auto mfmaEnc = dotOpEnc.getParent().dyn_cast<MfmaEncodingAttr>();
auto mfmaEnc = dotOpEnc.getParent().dyn_cast<AMDMfmaEncodingAttr>();

if (mfmaEnc) {
int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0;
@@ -259,6 +259,10 @@ compared to 1*64 when the hasLeadingOffset is false.
int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit;
int maxPhase = SIMDWidth / perPhase;

// TODO (zhanglx): figure out better parameters for mfma4
if (mfmaEnc.getMDim() == 4 )
maxPhase = 4;

return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
} else {
// Do not swizzle in case k dimension is not innermost.
@@ -700,16 +704,23 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
];
}

def MfmaEncodingAttr : DistributedEncoding<"MfmaEncoding", "mfma_encoding", [MmaEncodingTrait]> {
let mnemonic = "mfma";
def AMDMfmaEncodingAttr : DistributedEncoding<"AMDMfmaEncoding", "amd_mfma_encoding", [MmaEncodingTrait]> {
let mnemonic = "amd_mfma";

let description = [{
An encoding for tensors that have been produced by MI100 && MI200 tensor cores.
It is characterized by parameters `warpsPerCTA` and `nonKDim` that indicates how data should be partitioned
between waves (analogous to the term 'warp' used in NVIDIA's CUDA programming model).
An encoding for tensors that have been produced by tensor cores of AMD MI GPUs.
It is characterized by the following parameters:
- `versionMajor` and `versionMinor` indicates the GPU arch
- 1.0: gfx908, i.e. MI100
- 2.0: gfx90a: i.e. MI200, MI210, MI250
- 3.0: gfx940, gfx941, gfx942: MI300
- `warpsPerCTA` indicates the wave layout in the workgroup.
- `MDim` and `NDim` indicate the dimension of the output of the mfma instruction.
- `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout
without going to LDS. This is used in the case of chained dot (E.g. Flash-Attention kernel).

Example 1:
Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and nonKDim set to 32.
Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32.
The data will be distributed between threads as follows:

wave 0 wave 1
@@ -748,7 +759,7 @@ The data will be distributed between threads as follows:
[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ]

Example 2:
Suppose we have a tensor with a shape of [16, 32], warpsPerCTA set to [1, 2] and nonKDim set to 16.
Suppose we have a tensor with a shape of [16, 32], warpsPerCTA set to [1, 2] and MDim=NDim=16.
The data will be distributed between threads as follows:

wave 0 wave 1
@@ -769,12 +780,34 @@ The data will be distributed between threads as follows:
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]

Example 3:
Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and nonKDim set to 4.
The data will be distributed between threads as follows(note that each element is duploicated in 16 threads):
Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and MDim=NDim=4.
The data will be distributed between threads as follows(note that each element is duplicated in 16 threads):

M N -> wave 0 wave 2
| --------------------------/\-------------------------- ------------------------------/\------------------------------
V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ]
[ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ]
[ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ]
[ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ]
wave 1 wave 3
--------------------------/\-------------------------- ------------------------------/\------------------------------
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
}];

let parameters = (
ins
"unsigned":$nonKDim,
"unsigned": $versionMajor,
"unsigned": $versionMinor,
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
"unsigned":$MDim,
"unsigned":$NDim,
"bool":$isTransposed,
"CTALayoutAttr":$CTALayout
);
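For reference, a hedged sketch of constructing the renamed encoding through the builder that ODS typically generates from the parameter list above. The exact get signature, the CTALayoutAttr::getDefault helper, and the gfx90a version numbers (2, 0) are assumptions, not taken from this diff.

// Hedged sketch, not code from this commit: building an AMDMfmaEncodingAttr
// with the parameters declared above (versionMajor/Minor, warpsPerCTA,
// MDim/NDim, isTransposed, CTALayout). Signature and helpers are assumed.
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

static mlir::triton::gpu::AMDMfmaEncodingAttr
buildExampleMfmaEncoding(mlir::MLIRContext *ctx) {
  using namespace mlir::triton::gpu;
  auto ctaLayout = CTALayoutAttr::getDefault(ctx, /*rank=*/2);
  return AMDMfmaEncodingAttr::get(
      ctx,
      /*versionMajor=*/2, /*versionMinor=*/0, // gfx90a, i.e. the MI200 family
      /*warpsPerCTA=*/{1, 2},
      /*MDim=*/32, /*NDim=*/32,               // mfma32 output tile
      /*isTransposed=*/false, ctaLayout);
}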
6 changes: 3 additions & 3 deletions lib/Analysis/Allocation.cpp
@@ -12,6 +12,7 @@
#include <limits>
#include <numeric>

using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
@@ -20,7 +21,6 @@ using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getUniqueContigPerThread;
using ::mlir::triton::gpu::MfmaEncodingAttr;
using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -109,8 +109,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

if (srcLayout.isa<MfmaEncodingAttr>() &&
srcLayout.dyn_cast<MfmaEncodingAttr>().getIsTransposed() &&
if (srcLayout.isa<AMDMfmaEncodingAttr>() &&
srcLayout.dyn_cast<AMDMfmaEncodingAttr>().getIsTransposed() &&
dstLayout.isa<DotOperandEncodingAttr>())
if (isMfmaToDotShortcut(srcTy, dstTy))
return {};
5 changes: 3 additions & 2 deletions lib/Analysis/Utility.cpp
@@ -531,7 +531,7 @@ bool supportMMA(Value value, int version) {
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto srcLayout = srcTy.getEncoding();
auto dstLayout = dstTy.getEncoding();
auto mfmaLayout = srcLayout.cast<MfmaEncodingAttr>();
auto mfmaLayout = srcLayout.cast<AMDMfmaEncodingAttr>();
auto dotOperandLayout = dstLayout.cast<DotOperandEncodingAttr>();
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
// improved. In addition, we can enable this shortcut for regular MFMA
@@ -540,7 +540,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getKWidth() == 4 &&
dotOperandLayout.getParent() == mfmaLayout &&
mfmaLayout.getNonKDim() == 32 && mfmaLayout.getIsTransposed() &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
mfmaLayout.getIsTransposed() &&
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}


