Skip to content

Commit

Permalink
[MFMA] Support 64x4 and 4x64 tile size
Browse files Browse the repository at this point in the history
This PR enables two new MxN tile sizes: 64 x 4 and 4 x 64.
Both of them uses mfma 4x4 instructions.
  • Loading branch information
binarman authored and alefimov-amd committed Jan 17, 2024
1 parent 5da6276 commit ae2eff0
Show file tree
Hide file tree
Showing 14 changed files with 2,769 additions and 195 deletions.
5 changes: 4 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ compared to 1*64 when the hasLeadingOffset is false.
int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit;
int maxPhase = SIMDWidth / perPhase;
// TODO (zhanglx): figure out better parameters for mfma4
if (mfmaEnc.getMDim() == 4 )
auto mDim = mfmaEnc.getMDim();
auto nDim = mfmaEnc.getNDim();
auto nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim;
if (nonKDim == 4 )
maxPhase = 4;

return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
Expand Down
22 changes: 12 additions & 10 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ enum class MfmaTypeId : uint32_t {
};

struct MfmaInsnGroupSelectKey {
unsigned nonKDim;
unsigned mDim, nDim;
MfmaTypeId elemType;
int mfmaVersion;
};
Expand All @@ -187,23 +187,24 @@ constexpr typename std::underlying_type<T>::type cast_as_underlying(T t) {
struct MfmaInsnGroupSelectKeyInfo
: public llvm::DenseMapInfo<MfmaInsnGroupSelectKey> {
static inline MfmaInsnGroupSelectKey getEmptyKey() {
return {32, MfmaTypeId::Fp32TyId, 0};
return {32, 32, MfmaTypeId::Fp32TyId, 0};
}

static inline MfmaInsnGroupSelectKey getTombstoneKey() {
return {32, MfmaTypeId::Fp32TyId, -1};
return {32, 32, MfmaTypeId::Fp32TyId, -1};
}

static inline bool isEqual(const MfmaInsnGroupSelectKey &lhs,
const MfmaInsnGroupSelectKey &rhs) {
return lhs.nonKDim == rhs.nonKDim && lhs.elemType == rhs.elemType &&
lhs.mfmaVersion == rhs.mfmaVersion;
return lhs.mDim == rhs.mDim && lhs.nDim == rhs.nDim &&
lhs.elemType == rhs.elemType && lhs.mfmaVersion == rhs.mfmaVersion;
}

static unsigned getHashValue(const MfmaInsnGroupSelectKey &key) {
return llvm::detail::combineHashValue(
cast_as_underlying(key.elemType),
llvm::detail::combineHashValue(key.nonKDim, key.mfmaVersion));
auto dimHash = llvm::detail::combineHashValue(key.mDim, key.nDim);
auto verHash = llvm::detail::combineHashValue(dimHash, key.mfmaVersion);
auto elemHash = cast_as_underlying(key.elemType);
return llvm::detail::combineHashValue(elemHash, verHash);
}
};

Expand All @@ -214,8 +215,9 @@ class MfmaInsn {
MfmaInsnAttr attr;

public:
static FailureOr<MfmaInsn> selectMfma(unsigned nonKDim, Type elementTypeA,
Type elementTypeB, int mfmaVersion);
static FailureOr<MfmaInsn> selectMfma(unsigned mDim, unsigned nDim,
Type elementTypeA, Type elementTypeB,
int mfmaVersion);
MfmaInsn(Type elementTypeA, Type elementTypeB, const MfmaInsnAttr &attr);
unsigned getKDim();
unsigned getMDim();
Expand Down
7 changes: 4 additions & 3 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,11 @@ bool supportMMA(triton::DotOp op, int version) {
#ifdef USE_ROCM
static bool supportMFMAGranularity(int m, int n, int k) {
// these limitations are dtype dependent, in future we may relax them
const static std::pair<int, int> mfmaTypes[] = {{32, 8}, {16, 16}, {4, 64}};
const static std::tuple<int, int, int> mfmaTypes[] = {
{32, 32, 8}, {16, 16, 16}, {4, 4, 64}, {64, 4, 4}, {4, 64, 4}};
for (const auto &mfmaType : mfmaTypes) {
auto [granularityMN, granularityK] = mfmaType;
if (m % granularityMN != 0 || n % granularityMN != 0)
auto [granularityM, granularityN, granularityK] = mfmaType;
if (m % granularityM != 0 || n % granularityN != 0)
continue;
if (k % granularityK != 0)
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
* @param smemStrides strides in LDS tensor
* @param loadVecSize number of elements loaded by one operation
* @param iNonKDim non-K dimension of dot operand
* @param iKDim non-K dimension of dot operand
* @return vector (i-th element corresponds to i-th load instruction) of
* 2-element vectors(tensor row and col).
*/
Expand All @@ -140,7 +141,7 @@ computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, ArrayRef<Value> smemOffsets,
int loadVecSize, unsigned iNonKDim) {
int loadVecSize, unsigned iNonKDim, unsigned iKDim) {
auto numM = reps[0];
auto numK = reps[1];
const int loadsPerThread = numOfElems / loadVecSize;
Expand All @@ -159,8 +160,16 @@ computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,
Value laneHOffset;
if (iNonKDim == 32)
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
else
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
else {
// In this configuration wave contains 16 copies of same data
if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) {
laneHOffset = i32_val(0);
} else {
assert(iKDim * iNonKDim / numOfElems == 64 &&
"seems no all threads in wave contain unique elements");
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
}
}

for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
Value elemVOffset = _0;
Expand Down Expand Up @@ -197,7 +206,8 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
SharedEncodingAttr srcLayout, unsigned nonKDim) {
SharedEncodingAttr srcLayout, unsigned nonKDim,
unsigned kDim) {
SmallVector<Value> strides{smemObj.strides[0], smemObj.strides[1]};
SmallVector<Value> offsets{smemObj.offsets[0], smemObj.offsets[1]};

Expand All @@ -209,9 +219,9 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,
vectorSize = numOfElems;
}

auto mapping = computeTensorElemMapping(rewriter, loc, elemsPerInstr, waveId,
laneId, warpsPerGroup, numOfElems,
reps, offsets, vectorSize, nonKDim);
auto mapping = computeTensorElemMapping(
rewriter, loc, elemsPerInstr, waveId, laneId, warpsPerGroup, numOfElems,
reps, offsets, vectorSize, nonKDim, kDim);
llvm::SmallVector<Value> aOffsets(mapping.size());
for (int i = 0; i < mapping.size(); ++i) {
Value row = mapping[i][0];
Expand All @@ -226,7 +236,8 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
SharedEncodingAttr srcLayout, unsigned nonKDim) {
SharedEncodingAttr srcLayout, unsigned nonKDim,
unsigned kDim) {
// transpose reps and offsets, because operand B has layout equal to
// transposed operand A layout
SmallVector<int64_t> tElemsPerInstr{elemsPerInstr[1], elemsPerInstr[0]};
Expand All @@ -241,9 +252,9 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc,
vectorSize = numOfElems;
}

auto mapping = computeTensorElemMapping(rewriter, loc, tElemsPerInstr, waveId,
laneId, warpsPerGroup, numOfElems,
tReps, toffsets, vectorSize, nonKDim);
auto mapping = computeTensorElemMapping(
rewriter, loc, tElemsPerInstr, waveId, laneId, warpsPerGroup, numOfElems,
tReps, toffsets, vectorSize, nonKDim, kDim);
llvm::SmallVector<Value> bOffsets(mapping.size());
for (int i = 0; i < mapping.size(); ++i) {
// swap row and col, because operand B layout is a transposed operand A
Expand Down Expand Up @@ -333,7 +344,8 @@ bool fastPathAvailable(const SharedMemoryObject &smemObj,
// Computes offsets for operand B or transposed operand A
// @param rewriter
// @param loc
// @param elemsPerInstr operand tile shape consumed by one MFMA instruction
// @param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA
// instruction
// @param waveId wave id for the "non K" axis
// @param laneId lane id in warp [0..63]
// @param warpsPerGroup number of warps per horizontal axis
Expand All @@ -349,18 +361,36 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
auto numN = reps[1];
SmallVector<Value> offsets(numK * numN * numOfElems);

int lineSize = warpsPerGroup * elemsPerInstr[1] * numN;
Value _nonKDim = i32_val(elemsPerInstr[1]);
Value waveOffset = mul(waveId, i32_val(elemsPerInstr[1]));
auto iKDim = elemsPerInstr[0];
auto iNonKDim = elemsPerInstr[1];
int lineSize = warpsPerGroup * iNonKDim * numN;
Value _nonKDim = i32_val(iNonKDim);
Value waveOffset = mul(waveId, i32_val(iNonKDim));
Value colOffset = urem(laneId, _nonKDim);

for (int block = 0; block < numN; ++block) {
Value blockOffset = i32_val(block * elemsPerInstr[1] * warpsPerGroup);
Value blockOffset = i32_val(block * iNonKDim * warpsPerGroup);
for (int tile = 0; tile < numK; ++tile) {
Value tileOffset = i32_val(tile * elemsPerInstr[0] * lineSize);
Value tileOffset = i32_val(tile * iKDim * lineSize);
for (int elem = 0; elem < numOfElems; ++elem) {
Value halfOffset =
mul(udiv(laneId, _nonKDim), i32_val(numOfElems * lineSize));
// halfOffset is an offset related to wrapping of wave in the tile.
// for example, mfma 32 case (mapping of tensor elements to lane ids in
// wave):
//
// 0 1 2 3 ... 31
// 0 1 2 3 ... 31
// 0 1 2 3 ... 31
// 0 1 2 3 ... 31
// 32 33 34 35 ... 63 <- at this point wave is wrapping
// 32 33 34 35 ... 63
// 32 33 34 35 ... 63
// 32 33 34 35 ... 63
Value halfOffset;
if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4)
halfOffset = i32_val(0);
else
halfOffset =
mul(udiv(laneId, _nonKDim), i32_val(numOfElems * lineSize));
Value rowOffset = add(i32_val(elem * lineSize), halfOffset);
Value elemOffset = add(rowOffset, colOffset);
Value offset =
Expand Down Expand Up @@ -395,8 +425,10 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
int nonKDimIdx = opIdx == 0 ? 0 : 1;

auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
int nonKDim = mfmaLayout.getMDim();
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto mDim = mfmaLayout.getMDim();
auto nDim = mfmaLayout.getNDim();
assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
(mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();

auto aTensorTy = tensor.getType().cast<RankedTensorType>();
Expand All @@ -422,7 +454,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
Value spatialWaveId =
getWaveIdInBlock(rewriter, loc, linearWaveId, warpsPerCTA, mfmaInstrNonK,
shape[nonKDimIdx], nonKDimIdx);
int numOfElems = mfmaInstrNonK * mfmaInstrK / iWaveSize;
// number of duplicates of elements in wave
// In case of 64x4 x 4x4 multiplication, 4x4 B operand is duplicated 16 times
int numSubBlocks = 1;
if ((mfmaInstrK == 4 || mfmaInstrK == 1) && mfmaInstrNonK == 4)
numSubBlocks = 16;
int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWaveSize;
assert(numOfElems >= 1);

unsigned int maxNumWarps = shape[nonKDimIdx] / mfmaInstrNonK;
Expand Down Expand Up @@ -465,14 +502,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
// Normal path handles tensors that are k-major, in which case swizzling
// is enabled and it requires a 2-step method to compute the offsets.
if (opIdx == 0) {
offsets = computeOffsetsAType(rewriter, loc, elemsPerInstr, spatialWaveId,
lane, warpsPerGroupNonK, numOfElems,
numReps, smemObj, sharedLayout, nonKDim);
offsets = computeOffsetsAType(
rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerGroupNonK,
numOfElems, numReps, smemObj, sharedLayout, mDim, mfmaInstrK);
} else {
assert(opIdx == 1);
offsets = computeOffsetsBType(rewriter, loc, elemsPerInstr, spatialWaveId,
lane, warpsPerGroupNonK, numOfElems,
numReps, smemObj, sharedLayout, nonKDim);
offsets = computeOffsetsBType(
rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerGroupNonK,
numOfElems, numReps, smemObj, sharedLayout, nDim, mfmaInstrK);
}
smemBase = computeBasePtr(rewriter, loc, smemObj);
}
Expand All @@ -485,8 +522,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
assert(numOfElems % loadsPerThread == 0);

for (int nonK = 0; nonK < numRepNonK; ++nonK) {
Value blockVOffset = i32_val(nonK * mfmaInstrNonK * warpsPerGroupNonK);
Value offAdjust = mul(blockVOffset, i32_val(shape[order[0]]));
int blockNonKOffset = nonK * mfmaInstrNonK * warpsPerGroupNonK;
Value offAdjust = i32_val(blockNonKOffset * shape[order[0]]);
for (int k = 0; k < numRepK; ++k) {
auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
Expand Down
19 changes: 12 additions & 7 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ struct DotOpMFMAConversionHelper {
return rewriter.create(loweredOp)->getResult(0);
}

int getNumSubmatrices(Type elementType, int nonKDim) const {
switch (nonKDim) {
int getNumSubmatrices(Type elementType, int mDim, int nDim) const {
if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64)
return 1;
assert(mDim == nDim);
switch (mDim) {
case 32:
case 16:
return 1;
Expand Down Expand Up @@ -162,9 +165,11 @@ struct DotOpMFMAConversionHelper {
// Conduct the Dot conversion.
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const {
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto nonKDim = mfmaLayout.getMDim();
auto mDim = mfmaLayout.getMDim();
auto nDim = mfmaLayout.getNDim();
auto mfmaVersion = mfmaLayout.getVersionMajor();
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
(mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));

Value a = op.getA();
Value b = op.getB();
Expand All @@ -177,7 +182,7 @@ struct DotOpMFMAConversionHelper {

StringRef mfmaInsnName;
auto maybeMfmaInsn =
MfmaInsn::selectMfma(nonKDim, elemTyA, elemTyB, mfmaVersion);
MfmaInsn::selectMfma(mDim, nDim, elemTyA, elemTyB, mfmaVersion);
if (failed(maybeMfmaInsn))
llvm::report_fatal_error("No match found in MFMA database\n");
else
Expand Down Expand Up @@ -214,8 +219,8 @@ struct DotOpMFMAConversionHelper {
// compute number of output elements that each thread holds for one MFMA
// instruction. subBlocks
const int subBlocks =
getNumSubmatrices(aTensorTy.getElementType(), nonKDim);
auto elemsPerVec = nonKDim * nonKDim * subBlocks / warpSize;
getNumSubmatrices(aTensorTy.getElementType(), mDim, nDim);
auto elemsPerVec = mDim * nDim * subBlocks / warpSize;

auto vecTy = vec_ty(dstElemTy, elemsPerVec);
for (int m = 0; m < numRepM; ++m) {
Expand Down
Loading

0 comments on commit ae2eff0

Please sign in to comment.