Revert "[AMD][Navi31] Convert WMMA dot op to LLVM (#3199)" (#3284)
This reverts commit 9cfad37.


There was an in-flight collision with some refactoring, causing a build break.
ThomasRaoux authored Mar 5, 2024
1 parent 9cfad37 commit c1a3fff
Showing 5 changed files with 20 additions and 292 deletions.
28 changes: 14 additions & 14 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -165,7 +165,7 @@ SmallVector<unsigned> getContigPerThread(Attribute layout) {
     SmallVector<unsigned> contigPerThread(rank, 1);
     contigPerThread[rank - 1] = 2;
     return contigPerThread;
-  } else if (layout.isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>()) {
+  } else if (layout.isa<AMDMfmaEncodingAttr>()) {
     return {1, 1};
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
     auto parentLayout = sliceLayout.getParent();
@@ -286,7 +286,7 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
   ArrayRef<unsigned> ref;
   if (auto distributedLayout = layout.dyn_cast<DistributedEncodingTrait>())
     return distributedLayout.getCTAsPerCGA();
-  else if (layout.isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>())
+  else if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>())
     return {1, 1};
   else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
     ref = sharedLayout.getCTALayout().getCTAsPerCGA();
@@ -299,7 +299,7 @@ SmallVector<unsigned> getCTASplitNum(Attribute layout) {
   SmallVector<unsigned> res;
   if (auto distributedLayout = layout.dyn_cast<DistributedEncodingTrait>()) {
     return distributedLayout.getCTASplitNum();
-  } else if (layout.isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>()) {
+  } else if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
    res.resize(2);
    res[0] = res[1] = 1;
   } else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
@@ -315,7 +315,7 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
   SmallVector<unsigned> res;
   if (auto distributedLayout = layout.dyn_cast<DistributedEncodingTrait>()) {
     res = distributedLayout.getCTAOrder();
-  } else if (layout.isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>()) {
+  } else if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
     return {0, 1};
   } else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
     res = SmallVector<unsigned>(sharedLayout.getCTALayout().getCTAOrder());
@@ -370,8 +370,6 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
     warpsPerCTA = distributedLayout.getWarpsPerCTA();
   } else if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>())
     warpsPerCTA = mfmaLayout.getWarpsPerCTA();
-  else if (auto wmmaLayout = layout.dyn_cast<AMDWmmaEncodingAttr>())
-    warpsPerCTA = wmmaLayout.getWarpsPerCTA();
   else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
     return getNumWarpsPerCTA(dotLayout.getParent());
   else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
@@ -786,13 +784,15 @@ SmallVector<unsigned>
 AMDWmmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
                                        Type eltTy) const {
   size_t rank = shape.size();
-  assert(rank == 2 && "Unexpected rank of wmma layout");
+  assert(rank == 2 && "Unexpected rank of mfma layout");
 
   SmallVector<unsigned> elemsPerThread(rank);
-  auto mnkDim = getMNKDimPerWMMAInstr();
+  auto nonKDim = getMNKDimPerWMMAInstr()[0];
   auto elemsPerThreadPerTile = getSizePerThread();
-  return {ceil<unsigned>(shape[0], mnkDim[0]) * elemsPerThreadPerTile[0],
-          ceil<unsigned>(shape[1], mnkDim[1]) * elemsPerThreadPerTile[1]};
+  return {ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]) *
+              elemsPerThreadPerTile[0],
+          ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]) *
+              elemsPerThreadPerTile[1]};
 }
 
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
@@ -1594,11 +1594,11 @@ AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
 
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
-  auto warpsPerCTA = getWarpsPerCTA();
-  auto instSize = getWMMAElemsPerInstrForOperands();
-  SmallVector<int64_t> shapePerWarp;
+  int warpsPerCTAM = getWarpsPerCTA()[0];
+  int warpsPerCTAN = getWarpsPerCTA()[1];
+  auto tileSize = getWMMAElemsPerInstrForOperands();
   auto rep = getWMMARepForOperands(shape, eltTy, kWidth, opIdx);
-  return rep[0] * rep[1];
+  return product(tileSize) * product(rep) * warpsPerCTAN * warpsPerCTAM;
 }
 
 SmallVector<int64_t>
19 changes: 0 additions & 19 deletions test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir

This file was deleted.

1 change: 0 additions & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
@@ -2,7 +2,6 @@ add_triton_library(TritonAMDGPUToLLVM
   ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
   ConvertLayoutOpToLLVM.cpp
   DotOpToLLVM/MFMA.cpp
-  DotOpToLLVM/WMMA.cpp
   DotOpToLLVM.cpp
   ElementwiseOpToLLVM.cpp
   LoadStoreOpToLLVM.cpp
19 changes: 6 additions & 13 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp
@@ -8,7 +8,6 @@ using namespace mlir::triton;
 using ::AMD::ConvertTritonGPUOpToLLVMPattern;
 using ::AMD::ConvertTritonGPUOpToLLVMPatternBase;
 using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
-using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
@@ -18,10 +17,6 @@ namespace AMD {
 LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                           TritonGPUToLLVMTypeConverter *typeConverter,
                           ConversionPatternRewriter &rewriter);
-
-LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
-                          TritonGPUToLLVMTypeConverter *typeConverter,
-                          ConversionPatternRewriter &rewriter);
 #endif
 } // namespace AMD
 
@@ -49,14 +44,12 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
                        .getEncoding()
                        .dyn_cast<NvidiaMmaEncodingAttr>();
 #ifdef USE_ROCM
-    if (!isOuter) {
-      auto dEncoding = D.getType().cast<RankedTensorType>().getEncoding();
-      if (dEncoding.isa<AMDMfmaEncodingAttr>() && supportMFMA(op)) {
-        return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter);
-      }
-      if (dEncoding.isa<AMDWmmaEncodingAttr>()) {
-        return AMD::convertWMMA(op, adaptor, getTypeConverter(), rewriter);
-      }
+    AMDMfmaEncodingAttr mfmaLayout = D.getType()
+                                         .cast<RankedTensorType>()
+                                         .getEncoding()
+                                         .dyn_cast<AMDMfmaEncodingAttr>();
+    if (!isOuter && mfmaLayout && supportMFMA(op)) {
+      return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter);
     }
 #endif
 
245 changes: 0 additions & 245 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp

This file was deleted.
