[AMD] [MFMA] Support dot3d in MFMA layout (#3600)
- Support 3D tensors when emitting offsets for MFMA layouts
- Support 3D tensors in the Shared-to-dot-operand conversion
- Support dot3d in Dialect.cpp
- Replace amd::DecomposeConversion with common::ReduceDataDuplication

---------

Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
binarman and zhanglx13 authored Apr 9, 2024
1 parent cf27ce3 commit 3c2f88b
Showing 17 changed files with 337 additions and 270 deletions.
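Throughout the patch, dimensions are indexed from the back so that the rank-2 and rank-3 code paths share one convention: M is dim `rank - 2`, N is dim `rank - 1`, and dim 0 is the batch when rank is 3. A minimal sketch of that convention (hypothetical helper, not Triton API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper, not part of Triton: extract (batch, M, N) under the
// convention used throughout this patch: M = dim(rank - 2),
// N = dim(rank - 1), and dim 0 is the batch when rank == 3.
struct DotDims {
  int64_t batch, m, n;
};

DotDims getDotDims(const std::vector<int64_t> &shape) {
  auto rank = shape.size();
  assert(rank == 2 || rank == 3);
  return {rank == 3 ? shape[0] : 1, shape[rank - 2], shape[rank - 1]};
}
```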
1 change: 0 additions & 1 deletion bin/RegisterTritonDialects.h
@@ -64,7 +64,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   // mlir::registerTritonAMDGPUPasses();

   mlir::registerTritonAMDGPUAccelerateMatmul();
-  mlir::registerTritonAMDGPUDecomposeConversions();
   mlir::registerTritonAMDGPUOptimizeEpilogue();
   mlir::registerTritonAMDGPURemoveLayoutConversions();
   mlir::registerTritonAMDGPUReorderInstructions();
113 changes: 75 additions & 38 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -20,6 +20,7 @@ using namespace mlir::triton;

 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
 // Operators
+#define inttofloat(...) rewriter.create<LLVM::SIToFPOp>(loc, __VA_ARGS__)
 #define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
 #define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
 #define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
@@ -724,7 +725,7 @@ emitBaseIndexWithinCTAForMmaLayoutV2V3(Location loc, RewriterBase &rewriter,
   // level in the layout definition, and the tiling order of warpGrp->warp
   // must be fixed to meet the HW's needs. We may need to consider to
   // explicitly define warpGrpPerCTA for MMAv3 layout.
-  assert(rank == 2 && "MMAv3 layout is does not support 3D tensor yet");
+  assert(rank == 2 && "MMAv3 layout does not support 3D tensor yet");
   multiDimWarpId[rank - 2] = urem(warpId, warpsPerCTA[rank - 2]);
   multiDimWarpId[rank - 1] =
       urem(udiv(warpId, warpsPerCTA[rank - 2]), warpsPerCTA[rank - 1]);
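The two `urem`/`udiv` lines above are a standard 2-D delinearization of the linear warp id. A host-side integer sketch of the same arithmetic (plain C++, not the IR-building helpers):

```cpp
#include <cstdint>

// Plain-integer analogue of the urem/udiv chain above: split a linear warp
// id into coordinates over a warpsPerCTA[0] x warpsPerCTA[1] grid, with
// dim 0 varying fastest, mirroring the emitted LLVM arithmetic.
void delinearizeWarpId(uint32_t warpId, const uint32_t warpsPerCTA[2],
                       uint32_t multiDimWarpId[2]) {
  multiDimWarpId[0] = warpId % warpsPerCTA[0];                    // urem
  multiDimWarpId[1] = (warpId / warpsPerCTA[0]) % warpsPerCTA[1]; // udiv+urem
}
```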
@@ -783,73 +784,96 @@ emitBaseIndexForMfmaLayout(Location loc, RewriterBase &rewriter,
                            const AMDMfmaEncodingAttr &mfmaLayout,
                            RankedTensorType type) {
   auto shape = type.getShape();
+  auto rank = shape.size();
+  assert(rank == 2 || rank == 3);
   auto _warpsPerCTA = mfmaLayout.getWarpsPerCTA();
-  assert(_warpsPerCTA.size() == 2);
-  SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
-                                    i32_val(_warpsPerCTA[1])};
-  int nonKDim = mfmaLayout.getMDim();
+  SmallVector<Value> warpsPerCTA;
+  for (unsigned i = 0; i < rank; ++i)
+    warpsPerCTA.push_back(i32_val(_warpsPerCTA[i]));
+  unsigned mDim = mfmaLayout.getMDim();
+  unsigned nDim = mfmaLayout.getNDim();
+  assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
+         (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));

   Value threadId = getThreadId(rewriter, loc);
   Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout));
   Value effectiveWarpSize = warpSize;
-  if (nonKDim == 4) {
+  if (mDim == 4 && nDim == 4) {
     const int uniqueValuesPerWarp = 4;
     effectiveWarpSize = i32_val(uniqueValuesPerWarp);
   }
   Value laneId = urem(threadId, effectiveWarpSize);

   Value warpId = udiv(threadId, warpSize);
   SmallVector<Value> multiDimWarpId = delinearize(
       rewriter, loc, warpId, _warpsPerCTA, triton::gpu::getOrder(mfmaLayout));
-  if (shape[0] >= nonKDim) {
-    assert(shape[0] % nonKDim == 0);
-    multiDimWarpId[0] =
-        urem(multiDimWarpId[0], i32_val(ceil<unsigned>(shape[0], nonKDim)));
+  if (shape[rank - 2] >= mDim) {
+    assert(shape[rank - 2] % mDim == 0);
+    multiDimWarpId[rank - 2] =
+        urem(multiDimWarpId[rank - 2],
+             i32_val(ceil<unsigned>(shape[rank - 2], mDim)));
   }
-  if (shape[1] >= nonKDim) {
-    assert(shape[1] % nonKDim == 0);
-    multiDimWarpId[1] =
-        urem(multiDimWarpId[1], i32_val(ceil<unsigned>(shape[1], nonKDim)));
+  if (shape[rank - 1] >= nDim) {
+    assert(shape[rank - 1] % nDim == 0);
+    multiDimWarpId[rank - 1] =
+        urem(multiDimWarpId[rank - 1],
+             i32_val(ceil<unsigned>(shape[rank - 1], nDim)));
   }
-  Value offWarp0 = mul(multiDimWarpId[0], i32_val(nonKDim));
-  Value offWarp1 = mul(multiDimWarpId[1], i32_val(nonKDim));
+  Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mDim));
+  Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(nDim));

-  SmallVector<Value> multiDimBase(2);
+  SmallVector<Value> multiDimBase(rank);
   if (mfmaLayout.getIsTransposed()) {
-    multiDimBase[1] =
-        add(mul(i32_val(4), udiv(laneId, i32_val(nonKDim))), offWarp1);
-    multiDimBase[0] = add(urem(laneId, i32_val(nonKDim)), offWarp0);
+    multiDimBase[rank - 1] =
+        add(mul(i32_val(4), udiv(laneId, i32_val(mDim))), offWarp1);
+    multiDimBase[rank - 2] = add(urem(laneId, i32_val(mDim)), offWarp0);
   } else {
-    multiDimBase[0] =
-        add(mul(i32_val(4), udiv(laneId, i32_val(nonKDim))), offWarp0);
-    multiDimBase[1] = add(urem(laneId, i32_val(nonKDim)), offWarp1);
+    multiDimBase[rank - 2] =
+        add(mul(i32_val(4), udiv(laneId, i32_val(nDim))), offWarp0);
+    multiDimBase[rank - 1] = add(urem(laneId, i32_val(nDim)), offWarp1);
   }
+  // TODO(Lixun): It is assumed when rank = 3, warpsPerCTA is set to
+  // {numWarps, 1, 1}. We need to generalize the offset computation.
+  if (rank == 3) {
+    assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1);
+    multiDimBase[0] = urem(warpId, i32_val(shape[0]));
+  }
   return multiDimBase;
 }

 inline void emitMfmaOffsetForCTA(const AMDMfmaEncodingAttr &mfmaLayout,
                                  SmallVector<SmallVector<unsigned>> &offsets,
-                                 unsigned ctaOffsetX, unsigned ctaOffsetY) {
-  auto nonKDim = mfmaLayout.getMDim();
+                                 unsigned bOff, unsigned ctaOffsetX,
+                                 unsigned ctaOffsetY) {
+  auto mDim = mfmaLayout.getMDim();
+  auto nDim = mfmaLayout.getNDim();
+  assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
+         (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));
   // MFMA output tile consists of repeated "dot operand B" layout groups along
   // row axis. This variable defines number of these groups.
-  const unsigned numGroups = (nonKDim == 32 ? 4 : 1);
+  DenseMap<int, int> groups{{4, 1}, {16, 1}, {32, 4}};
+  unsigned numGroups = groups.at(std::min(mDim, nDim));
   const unsigned elemsPerThreadPerGroup = 4;
   auto warpSize = getWarpSize(mfmaLayout);
   assert(warpSize == 64);
   auto shapePerCta = getShapePerCTATile(mfmaLayout);
+  auto rank = shapePerCta.size();
+  SmallVector<unsigned> elemOff(rank, 0);
   for (unsigned block = 0; block < numGroups; block++) {
     unsigned rowOrColOffset =
-        block * elemsPerThreadPerGroup * warpSize / nonKDim;
+        block * elemsPerThreadPerGroup * warpSize / std::min(mDim, nDim);
     for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
       if (mfmaLayout.getIsTransposed()) {
-        offsets.push_back(
-            {ctaOffsetX * shapePerCta[0],
-             ctaOffsetY * shapePerCta[1] + elem + rowOrColOffset});
+        elemOff[rank - 2] = ctaOffsetX * shapePerCta[rank - 2];
+        elemOff[rank - 1] =
+            ctaOffsetY * shapePerCta[rank - 1] + elem + rowOrColOffset;
       } else {
-        offsets.push_back({ctaOffsetX * shapePerCta[0] + elem + rowOrColOffset,
-                           ctaOffsetY * shapePerCta[1]});
+        elemOff[rank - 2] =
+            ctaOffsetX * shapePerCta[rank - 2] + elem + rowOrColOffset;
+        elemOff[rank - 1] = ctaOffsetY * shapePerCta[rank - 1];
       }
+      if (rank == 3)
+        elemOff[0] = bOff;
+      offsets.push_back(elemOff);
     }
   }
 }
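To make the group structure concrete, here is a host-side sketch (a simplified standalone analogue, not the helper itself) of the offsets emitMfmaOffsetForCTA produces for the ctaOffsetX = ctaOffsetY = 0 tile in the non-transposed case; for mDim == nDim == 32 the four groups start at rows 0, 8, 16 and 24:

```cpp
#include <algorithm>
#include <vector>

// Sketch of the loop structure above for one warp tile at CTA offset (0, 0),
// non-transposed case. Each of numGroups groups contributes
// elemsPerThreadPerGroup consecutive rows per thread; the column comes from
// the lane id via the base index, so it is 0 here. For a rank-3 layout, the
// batch coordinate bOff is stored in dim 0 of every offset.
std::vector<std::vector<unsigned>>
mfmaTileOffsets(unsigned mDim, unsigned nDim, unsigned bOff, unsigned rank) {
  unsigned numGroups = (std::min(mDim, nDim) == 32 ? 4 : 1);
  const unsigned elemsPerThreadPerGroup = 4, warpSize = 64;
  std::vector<std::vector<unsigned>> offsets;
  for (unsigned block = 0; block < numGroups; ++block) {
    // For mDim == nDim == 32: block * 4 * 64 / 32 = 8 * block.
    unsigned rowOff =
        block * elemsPerThreadPerGroup * warpSize / std::min(mDim, nDim);
    for (unsigned elem = 0; elem < elemsPerThreadPerGroup; ++elem) {
      std::vector<unsigned> elemOff(rank, 0);
      elemOff[rank - 2] = elem + rowOff; // row within the warp tile
      elemOff[rank - 1] = 0;             // column handled via lane id
      if (rank == 3)
        elemOff[0] = bOff;               // batch coordinate
      offsets.push_back(elemOff);
    }
  }
  return offsets;
}
```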
@@ -862,16 +886,29 @@ emitOffsetForMfmaLayout(const AMDMfmaEncodingAttr &mfmaLayout,
   auto shapePerCTA = getShapePerCTA(mfmaLayout, tensorShape);
   auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
   auto rank = type.getRank();
-  SmallVector<unsigned> numWarpsPerDim(rank);
+  SmallVector<unsigned> numReps(rank);
+  unsigned mDim = mfmaLayout.getMDim();
+  unsigned nDim = mfmaLayout.getNDim();
+  assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
+         (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));
+  SmallVector<unsigned> shapePerWarp(rank, 1);
+  shapePerWarp[rank - 2] = mDim;
+  shapePerWarp[rank - 1] = nDim;
   for (unsigned d = 0; d < rank; ++d) {
     unsigned inPerCTA = std::min<unsigned>(tensorShape[d], shapePerCTA[d]);
     unsigned inPerWarp = ceil<unsigned>(inPerCTA, warpsPerCTA[d]);
-    numWarpsPerDim[d] = ceil<unsigned>(inPerWarp, mfmaLayout.getMDim());
+    numReps[d] = ceil<unsigned>(inPerWarp, shapePerWarp[d]);
   }

-  for (unsigned i = 0; i < numWarpsPerDim[0]; ++i) {
-    for (unsigned j = 0; j < numWarpsPerDim[1]; ++j) {
-      emitMfmaOffsetForCTA(mfmaLayout, offsets, i, j);
+  unsigned repBatch = rank == 3 ? numReps[0] : 1;
+  auto warpsPerBatch =
+      rank == 3 ? std::min<unsigned>(tensorShape[0], warpsPerCTA[0]) : 1;
+
+  for (unsigned b = 0; b < repBatch; ++b) {
+    for (unsigned i = 0; i < numReps[rank - 2]; ++i) {
+      for (unsigned j = 0; j < numReps[rank - 1]; ++j) {
+        emitMfmaOffsetForCTA(mfmaLayout, offsets, b * warpsPerBatch, i, j);
+      }
     }
   }
   return offsets;
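A worked example of the repetition counts (assumed numbers; shapePerCTA is taken equal to tensorShape, i.e. no CGA splitting):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Worked example: a 4x64x64 tensor on 32x32 MFMA with warpsPerCTA = {2,1,1}.
// shapePerWarp = {1, 32, 32}, so numReps = {2, 2, 2}: each warp covers two
// batches and a 2x2 grid of 32x32 tiles. repBatch = 2 and
// warpsPerBatch = min(4, 2) = 2, so emitMfmaOffsetForCTA is called with
// bOff = 0 for the first batch repetition and bOff = 2 for the second.
int main() {
  std::vector<unsigned> tensorShape{4, 64, 64}, warpsPerCTA{2, 1, 1},
      shapePerWarp{1, 32, 32}, numReps(3);
  for (unsigned d = 0; d < 3; ++d) {
    unsigned inPerWarp =
        (tensorShape[d] + warpsPerCTA[d] - 1) / warpsPerCTA[d]; // ceil
    numReps[d] = (inPerWarp + shapePerWarp[d] - 1) / shapePerWarp[d];
  }
  unsigned repBatch = numReps[0];
  unsigned warpsPerBatch = std::min(tensorShape[0], warpsPerCTA[0]);
  for (unsigned b = 0; b < repBatch; ++b)
    std::printf("batch rep %u -> bOff = %u, %u x %u MN reps\n", b,
                b * warpsPerBatch, numReps[1], numReps[2]);
}
```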
8 changes: 6 additions & 2 deletions lib/Analysis/Utility.cpp
@@ -465,8 +465,12 @@ bool supportMFMA(triton::DotOp op) {
   auto aShape = aTy.getShape();
   auto bShape = bTy.getShape();

-  assert(aShape[1] == bShape[0]);
-  if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1]))
+  auto rank = aShape.size();
+  auto M = aShape[rank - 2];
+  auto N = bShape[rank - 1];
+  auto K = aShape[rank - 1];
+  assert(K == bShape[rank - 2]);
+  if (!supportMFMAGranularity(M, N, K))
     return false;

   return true;
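For illustration, the same back-indexed extraction on assumed rank-3 shapes (standalone, not a Triton call):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// With a leading batch dim, M/N/K still live at positions rank-2/rank-1,
// so indexing from the back works for rank 2 and rank 3 alike.
int main() {
  // A = {batch, M, K} = {4, 32, 64}, B = {batch, K, N} = {4, 64, 16}.
  std::vector<int64_t> aShape{4, 32, 64}, bShape{4, 64, 16};
  auto rank = aShape.size();
  auto M = aShape[rank - 2]; // 32
  auto N = bShape[rank - 1]; // 16
  auto K = aShape[rank - 1]; // 64
  assert(K == bShape[rank - 2] && "inner dims of A and B must match");
  (void)M; (void)N;
}
```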
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -542,7 +542,7 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
     assert(rank == 2);
     SmallVector<Value> multiDimOffset(rank);
     if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
-      emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0],
+      emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0],
                            multiDimCTAInRepId[1]);
     } else if (auto wmmaLayout = layout.dyn_cast<AMDWmmaEncodingAttr>()) {
       emitWmmaOffsetForCTA(wmmaLayout, offsets, multiDimCTAInRepId[0],