Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MFMA] Support 64x4 and 4x64 tile size #469

Merged
merged 3 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,11 @@ compared to 1*64 when the hasLeadingOffset is false.
vecSize = 8;
int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
// TODO (zhanglx): figure out better parameters for mfma4
if (4 == mfmaEnc.getMDim())
maxPhase = 4;
auto mDim = mfmaEnc.getMDim();
auto nDim = mfmaEnc.getNDim();
auto nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim;
if (4 == nonKDim)
maxPhase = 4;
assert(maxPhase > 0);

return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
Expand Down
22 changes: 12 additions & 10 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ enum class MfmaTypeId : uint32_t {
};

struct MfmaInsnGroupSelectKey {
unsigned nonKDim;
unsigned mDim, nDim;
MfmaTypeId elemType;
int mfmaVersion;
};
Expand All @@ -187,23 +187,24 @@ constexpr typename std::underlying_type<T>::type cast_as_underlying(T t) {
struct MfmaInsnGroupSelectKeyInfo
: public llvm::DenseMapInfo<MfmaInsnGroupSelectKey> {
static inline MfmaInsnGroupSelectKey getEmptyKey() {
return {32, MfmaTypeId::Fp32TyId, 0};
return {32, 32, MfmaTypeId::Fp32TyId, 0};
}

static inline MfmaInsnGroupSelectKey getTombstoneKey() {
return {32, MfmaTypeId::Fp32TyId, -1};
return {32, 32, MfmaTypeId::Fp32TyId, -1};
}

static inline bool isEqual(const MfmaInsnGroupSelectKey &lhs,
const MfmaInsnGroupSelectKey &rhs) {
return lhs.nonKDim == rhs.nonKDim && lhs.elemType == rhs.elemType &&
lhs.mfmaVersion == rhs.mfmaVersion;
return lhs.mDim == rhs.mDim && lhs.nDim == rhs.nDim &&
lhs.elemType == rhs.elemType && lhs.mfmaVersion == rhs.mfmaVersion;
}

static unsigned getHashValue(const MfmaInsnGroupSelectKey &key) {
return llvm::detail::combineHashValue(
cast_as_underlying(key.elemType),
llvm::detail::combineHashValue(key.nonKDim, key.mfmaVersion));
auto dimHash = llvm::detail::combineHashValue(key.mDim, key.nDim);
auto verHash = llvm::detail::combineHashValue(dimHash, key.mfmaVersion);
auto elemHash = cast_as_underlying(key.elemType);
return llvm::detail::combineHashValue(elemHash, verHash);
}
};

Expand All @@ -214,8 +215,9 @@ class MfmaInsn {
MfmaInsnAttr attr;

public:
static FailureOr<MfmaInsn> selectMfma(unsigned nonKDim, Type elementTypeA,
Type elementTypeB, int mfmaVersion);
static FailureOr<MfmaInsn> selectMfma(unsigned mDim, unsigned nDim,
Type elementTypeA, Type elementTypeB,
int mfmaVersion);
MfmaInsn(Type elementTypeA, Type elementTypeB, const MfmaInsnAttr &attr);
unsigned getKDim();
unsigned getMDim();
Expand Down
7 changes: 4 additions & 3 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,11 @@ bool supportMMA(triton::DotOp op, int version) {
#ifdef USE_ROCM
static bool supportMFMAGranularity(int m, int n, int k) {
// these limitations are dtype dependent, in future we may relax them
const static std::pair<int, int> mfmaTypes[] = {{32, 8}, {16, 16}, {4, 64}};
const static std::tuple<int, int, int> mfmaTypes[] = {
{32, 32, 8}, {16, 16, 16}, {4, 4, 64}, {64, 4, 4}, {4, 64, 4}};
for (const auto &mfmaType : mfmaTypes) {
auto [granularityMN, granularityK] = mfmaType;
if (m % granularityMN != 0 || n % granularityMN != 0)
auto [granularityM, granularityN, granularityK] = mfmaType;
if (m % granularityM != 0 || n % granularityN != 0)
continue;
if (k % granularityK != 0)
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
* @param smemStrides strides in LDS tensor
* @param loadVecSize number of elements loaded by one operation
* @param iNonKDim non-K dimension of dot operand
 * @param iKDim K dimension of dot operand
zhanglx13 marked this conversation as resolved.
Show resolved Hide resolved
* @return vector (i-th element corresponds to i-th load instruction) of
* 2-element vectors(tensor row and col).
*/
Expand All @@ -140,7 +141,7 @@ computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, ArrayRef<Value> smemOffsets,
int loadVecSize, unsigned iNonKDim) {
int loadVecSize, unsigned iNonKDim, unsigned iKDim) {
auto numM = reps[0];
auto numK = reps[1];
const int loadsPerThread = numOfElems / loadVecSize;
Expand All @@ -159,8 +160,16 @@ computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,
Value laneHOffset;
if (iNonKDim == 32)
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
else
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
else {
// In this configuration the wave contains 16 copies of the same data
if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) {
laneHOffset = i32_val(0);
} else {
assert(iKDim * iNonKDim / numOfElems == 64 &&
"seems not all threads in wave contain unique elements");
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
}
}

for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
Value elemVOffset = _0;
Expand Down Expand Up @@ -197,7 +206,8 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
SharedEncodingAttr srcLayout, unsigned nonKDim) {
SharedEncodingAttr srcLayout, unsigned nonKDim,
unsigned kDim) {
SmallVector<Value> strides{smemObj.strides[0], smemObj.strides[1]};
SmallVector<Value> offsets{smemObj.offsets[0], smemObj.offsets[1]};

Expand All @@ -209,9 +219,9 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,
vectorSize = numOfElems;
}

auto mapping = computeTensorElemMapping(rewriter, loc, elemsPerInstr, waveId,
laneId, warpsPerGroup, numOfElems,
reps, offsets, vectorSize, nonKDim);
auto mapping = computeTensorElemMapping(
rewriter, loc, elemsPerInstr, waveId, laneId, warpsPerGroup, numOfElems,
reps, offsets, vectorSize, nonKDim, kDim);
llvm::SmallVector<Value> aOffsets(mapping.size());
for (int i = 0; i < mapping.size(); ++i) {
Value row = mapping[i][0];
Expand All @@ -226,7 +236,8 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
SharedEncodingAttr srcLayout, unsigned nonKDim) {
SharedEncodingAttr srcLayout, unsigned nonKDim,
unsigned kDim) {
// transpose reps and offsets, because operand B has layout equal to
// transposed operand A layout
SmallVector<int64_t> tElemsPerInstr{elemsPerInstr[1], elemsPerInstr[0]};
Expand All @@ -241,9 +252,9 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc,
vectorSize = numOfElems;
}

auto mapping = computeTensorElemMapping(rewriter, loc, tElemsPerInstr, waveId,
laneId, warpsPerGroup, numOfElems,
tReps, toffsets, vectorSize, nonKDim);
auto mapping = computeTensorElemMapping(
rewriter, loc, tElemsPerInstr, waveId, laneId, warpsPerGroup, numOfElems,
tReps, toffsets, vectorSize, nonKDim, kDim);
llvm::SmallVector<Value> bOffsets(mapping.size());
for (int i = 0; i < mapping.size(); ++i) {
// swap row and col, because operand B layout is a transposed operand A
Expand Down Expand Up @@ -333,7 +344,8 @@ bool fastPathAvailable(const SharedMemoryObject &smemObj,
// Computes offsets for operand B or transposed operand A
// @param rewriter
// @param loc
// @param elemsPerInstr operand tile shape consumed by one MFMA instruction
// @param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA
// instruction
// @param waveId wave id for the "non K" axis
// @param laneId lane id in warp [0..63]
// @param warpsPerGroup number of warps per horizontal axis
Expand All @@ -349,18 +361,36 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
auto numN = reps[1];
SmallVector<Value> offsets(numK * numN * numOfElems);

int lineSize = warpsPerGroup * elemsPerInstr[1] * numN;
Value _nonKDim = i32_val(elemsPerInstr[1]);
Value waveOffset = mul(waveId, i32_val(elemsPerInstr[1]));
auto iKDim = elemsPerInstr[0];
auto iNonKDim = elemsPerInstr[1];
int lineSize = warpsPerGroup * iNonKDim * numN;
Value _nonKDim = i32_val(iNonKDim);
Value waveOffset = mul(waveId, i32_val(iNonKDim));
Value colOffset = urem(laneId, _nonKDim);

for (int block = 0; block < numN; ++block) {
Value blockOffset = i32_val(block * elemsPerInstr[1] * warpsPerGroup);
Value blockOffset = i32_val(block * iNonKDim * warpsPerGroup);
for (int tile = 0; tile < numK; ++tile) {
Value tileOffset = i32_val(tile * elemsPerInstr[0] * lineSize);
Value tileOffset = i32_val(tile * iKDim * lineSize);
for (int elem = 0; elem < numOfElems; ++elem) {
Value halfOffset =
mul(udiv(laneId, _nonKDim), i32_val(numOfElems * lineSize));
// halfOffset is an offset related to wrapping of wave in the tile.
// for example, mfma 32 case (mapping of tensor elements to lane ids in
// wave):
//
// 0 1 2 3 ... 31
// 0 1 2 3 ... 31
// 0 1 2 3 ... 31
// 0 1 2 3 ... 31
// 32 33 34 35 ... 63 <- at this point wave is wrapping
// 32 33 34 35 ... 63
// 32 33 34 35 ... 63
// 32 33 34 35 ... 63
Value halfOffset;
if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4)
halfOffset = i32_val(0);
else
halfOffset =
mul(udiv(laneId, _nonKDim), i32_val(numOfElems * lineSize));
Value rowOffset = add(i32_val(elem * lineSize), halfOffset);
Value elemOffset = add(rowOffset, colOffset);
Value offset =
Expand Down Expand Up @@ -395,8 +425,10 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
int nonKDimIdx = opIdx == 0 ? 0 : 1;

auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
int nonKDim = mfmaLayout.getMDim();
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto mDim = mfmaLayout.getMDim();
auto nDim = mfmaLayout.getNDim();
assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
(mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();

auto aTensorTy = tensor.getType().cast<RankedTensorType>();
Expand All @@ -422,7 +454,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
Value spatialWaveId =
getWaveIdInBlock(rewriter, loc, linearWaveId, warpsPerCTA, mfmaInstrNonK,
shape[nonKDimIdx], nonKDimIdx);
int numOfElems = mfmaInstrNonK * mfmaInstrK / iWaveSize;
// number of duplicates of elements in wave
// In case of 64x4 x 4x4 multiplication, 4x4 B operand is duplicated 16 times
int numSubBlocks = 1;
if ((mfmaInstrK == 4 || mfmaInstrK == 1) && mfmaInstrNonK == 4)
numSubBlocks = 16;
int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWaveSize;
assert(numOfElems >= 1);

unsigned int maxNumWarps = shape[nonKDimIdx] / mfmaInstrNonK;
Expand Down Expand Up @@ -465,14 +502,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
// Normal path handles tensors that are k-major, in which case swizzling
// is enabled and it requires a 2-step method to compute the offsets.
if (opIdx == 0) {
offsets = computeOffsetsAType(rewriter, loc, elemsPerInstr, spatialWaveId,
lane, warpsPerGroupNonK, numOfElems,
numReps, smemObj, sharedLayout, nonKDim);
offsets = computeOffsetsAType(
rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerGroupNonK,
numOfElems, numReps, smemObj, sharedLayout, mDim, mfmaInstrK);
} else {
assert(opIdx == 1);
offsets = computeOffsetsBType(rewriter, loc, elemsPerInstr, spatialWaveId,
lane, warpsPerGroupNonK, numOfElems,
numReps, smemObj, sharedLayout, nonKDim);
offsets = computeOffsetsBType(
rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerGroupNonK,
numOfElems, numReps, smemObj, sharedLayout, nDim, mfmaInstrK);
}
smemBase = computeBasePtr(rewriter, loc, smemObj);
}
Expand All @@ -485,8 +522,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
assert(numOfElems % loadsPerThread == 0);

for (int nonK = 0; nonK < numRepNonK; ++nonK) {
Value blockVOffset = i32_val(nonK * mfmaInstrNonK * warpsPerGroupNonK);
Value offAdjust = mul(blockVOffset, i32_val(shape[order[0]]));
int blockNonKOffset = nonK * mfmaInstrNonK * warpsPerGroupNonK;
Value offAdjust = i32_val(blockNonKOffset * shape[order[0]]);
for (int k = 0; k < numRepK; ++k) {
auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
Expand Down
19 changes: 12 additions & 7 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ struct DotOpMFMAConversionHelper {
return rewriter.create(loweredOp)->getResult(0);
}

int getNumSubmatrices(Type elementType, int nonKDim) const {
switch (nonKDim) {
int getNumSubmatrices(Type elementType, int mDim, int nDim) const {
if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64)
return 1;
assert(mDim == nDim);
switch (mDim) {
case 32:
case 16:
return 1;
Expand Down Expand Up @@ -162,9 +165,11 @@ struct DotOpMFMAConversionHelper {
// Conduct the Dot conversion.
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const {
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto nonKDim = mfmaLayout.getMDim();
auto mDim = mfmaLayout.getMDim();
auto nDim = mfmaLayout.getNDim();
auto mfmaVersion = mfmaLayout.getVersionMajor();
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
(mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));

Value a = op.getA();
Value b = op.getB();
Expand All @@ -177,7 +182,7 @@ struct DotOpMFMAConversionHelper {

StringRef mfmaInsnName;
auto maybeMfmaInsn =
MfmaInsn::selectMfma(nonKDim, elemTyA, elemTyB, mfmaVersion);
MfmaInsn::selectMfma(mDim, nDim, elemTyA, elemTyB, mfmaVersion);
if (failed(maybeMfmaInsn))
llvm::report_fatal_error("No match found in MFMA database\n");
else
Expand Down Expand Up @@ -214,8 +219,8 @@ struct DotOpMFMAConversionHelper {
// compute number of output elements that each thread holds for one MFMA
// instruction. subBlocks
const int subBlocks =
getNumSubmatrices(aTensorTy.getElementType(), nonKDim);
auto elemsPerVec = nonKDim * nonKDim * subBlocks / warpSize;
getNumSubmatrices(aTensorTy.getElementType(), mDim, nDim);
auto elemsPerVec = mDim * nDim * subBlocks / warpSize;

auto vecTy = vec_ty(dstElemTy, elemsPerVec);
for (int m = 0; m < numRepM; ++m) {
Expand Down
Loading