Skip to content

Commit

Permalink
[AMD] Refactor SharedToDotOperandMFMA (#3264)
Browse files Browse the repository at this point in the history
This PR updates SharedToDotOperandMFMA.cpp and MFMA.cpp.
- SharedToDotOperandMFMA.cpp is up to date with triton-mlir as of today,
which includes changes until ROCm#482
  - Fixed issue with opaque pointers
- Fixed API for `getMFMAElemsPerInstrForOperands` and
`getMFMARepForOperands`
- MFMA.cpp is synced with triton-mlir@6bb04d, which includes changes
until ROCm#469

Note to @binarman: changes in other files from
ROCm#469 are not included in this PR. We
can bring up the support for mfma 64x4 and 4x64 later.
  • Loading branch information
zhanglx13 authored Mar 4, 2024
1 parent a815d7f commit 107672f
Show file tree
Hide file tree
Showing 5 changed files with 270 additions and 449 deletions.
5 changes: 2 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -819,9 +819,8 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
SmallVector<int64_t> getMFMAElemsPerInstrForOperands(int kWidth, int opIdx) const;
SmallVector<int64_t> getMFMARepForOperands(ArrayRef<int64_t> operandShape,
Type elemType, int kWidth, int opIdx) const;
SmallVector<int64_t> getMFMAInstrShapeForOperands(int kWidth, int opIdx) const;
SmallVector<int64_t> getMFMARepForOperands(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;


}];
Expand Down
33 changes: 19 additions & 14 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1466,24 +1466,29 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getSizePerThread() const {
}

SmallVector<int64_t>
AMDMfmaEncodingAttr::getMFMAElemsPerInstrForOperands(int kWidth,
int opIdx) const {
int64_t nonKDim = getMDim();
assert(nonKDim == 32 || nonKDim == 16);
int64_t kDim = kWidth * (nonKDim == 32 ? 2 : 4);
AMDMfmaEncodingAttr::getMFMAInstrShapeForOperands(int kWidth, int opIdx) const {
unsigned mDim = getMDim();
unsigned nDim = getNDim();
assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) ||
(mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));
constexpr int waveSize = 64; // MFMA is used on wave64 architectures only
int kGroups = -1;
if (mDim == nDim)
kGroups = waveSize / mDim;
if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64)
kGroups = 1;
int64_t kDim = kWidth * kGroups;
if (opIdx == 0)
return {nonKDim, kDim};
else {
return {mDim, kDim};
else
assert(opIdx == 1);
return {kDim, nonKDim};
}
return {kDim, nDim};
}

SmallVector<int64_t>
AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef<int64_t> operandShape,
Type elemType, int kWidth,
int opIdx) const {
auto operandTileShape = getMFMAElemsPerInstrForOperands(kWidth, opIdx);
int kWidth, int opIdx) const {
auto operandTileShape = getMFMAInstrShapeForOperands(kWidth, opIdx);
auto warpsPerCTA = getWarpsPerCTA();
if (opIdx == 0)
return {std::max<int64_t>(1, operandShape[0] /
Expand All @@ -1502,8 +1507,8 @@ unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperands(
int warpsPerCTAM = getWarpsPerCTA()[0];
int warpsPerCTAN = getWarpsPerCTA()[1];
constexpr int waveSize = 64;
auto tileSize = getMFMAElemsPerInstrForOperands(kWidth, opIdx);
auto rep = getMFMARepForOperands(shape, eltTy, kWidth, opIdx);
auto tileSize = getMFMAInstrShapeForOperands(kWidth, opIdx);
auto rep = getMFMARepForOperands(shape, kWidth, opIdx);
return rep[0] * rep[1];
}

Expand Down
Loading

0 comments on commit 107672f

Please sign in to comment.