Skip to content

Commit

Permalink
[AMD][Navi31] Convert WMMA dot op to LLVM
Browse files Browse the repository at this point in the history
-Add WMMA conversion logic for dot operation
-Fix helper hunctions for WMMA layout
-Add lit test for WMMA dot operation conversion

Signed-off-by: joviliast <iveselov.nn@gmail.com>
  • Loading branch information
joviliast committed Mar 1, 2024
1 parent 38565ba commit 8d1c067
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 20 deletions.
28 changes: 14 additions & 14 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ SmallVector<unsigned> getContigPerThread(Attribute layout) {
SmallVector<unsigned> contigPerThread(rank, 1);
contigPerThread[rank - 1] = 2;
return contigPerThread;
} else if (layout.isa<AMDMfmaEncodingAttr>()) {
} else if (layout.isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>()) {
return {1, 1};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
Expand Down Expand Up @@ -286,7 +286,7 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
ArrayRef<unsigned> ref;
if (auto distributedLayout = layout.dyn_cast<DistributedEncodingTrait>())
return distributedLayout.getCTAsPerCGA();
else if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>())
else if (layout.isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>())
return {1, 1};
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
ref = sharedLayout.getCTALayout().getCTAsPerCGA();
Expand All @@ -299,7 +299,7 @@ SmallVector<unsigned> getCTASplitNum(Attribute layout) {
SmallVector<unsigned> res;
if (auto distributedLayout = layout.dyn_cast<DistributedEncodingTrait>()) {
return distributedLayout.getCTASplitNum();
} else if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
} else if (layout.isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>()) {
res.resize(2);
res[0] = res[1] = 1;
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
Expand All @@ -315,7 +315,7 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
SmallVector<unsigned> res;
if (auto distributedLayout = layout.dyn_cast<DistributedEncodingTrait>()) {
res = distributedLayout.getCTAOrder();
} else if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
} else if (layout.isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>()) {
return {0, 1};
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
res = SmallVector<unsigned>(sharedLayout.getCTALayout().getCTAOrder());
Expand Down Expand Up @@ -370,6 +370,8 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
warpsPerCTA = distributedLayout.getWarpsPerCTA();
} else if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>())
warpsPerCTA = mfmaLayout.getWarpsPerCTA();
else if (auto wmmaLayout = layout.dyn_cast<AMDWmmaEncodingAttr>())
warpsPerCTA = wmmaLayout.getWarpsPerCTA();
else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
return getNumWarpsPerCTA(dotLayout.getParent());
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
Expand Down Expand Up @@ -784,15 +786,13 @@ SmallVector<unsigned>
AMDWmmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
size_t rank = shape.size();
assert(rank == 2 && "Unexpected rank of mfma layout");
assert(rank == 2 && "Unexpected rank of wmma layout");

SmallVector<unsigned> elemsPerThread(rank);
auto nonKDim = getMNKDimPerWMMAInstr()[0];
auto mnkDim = getMNKDimPerWMMAInstr();
auto elemsPerThreadPerTile = getSizePerThread();
return {ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]) *
elemsPerThreadPerTile[0],
ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]) *
elemsPerThreadPerTile[1]};
return {ceil<unsigned>(shape[0], mnkDim[0]) * elemsPerThreadPerTile[0],
ceil<unsigned>(shape[1], mnkDim[1]) * elemsPerThreadPerTile[1]};
}

unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Expand Down Expand Up @@ -1588,11 +1588,11 @@ AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,

unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands(
ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
int warpsPerCTAM = getWarpsPerCTA()[0];
int warpsPerCTAN = getWarpsPerCTA()[1];
auto tileSize = getWMMAElemsPerInstrForOperands();
auto warpsPerCTA = getWarpsPerCTA();
auto instSize = getWMMAElemsPerInstrForOperands();
SmallVector<int64_t> shapePerWarp;
auto rep = getWMMARepForOperands(shape, eltTy, kWidth, opIdx);
return product(tileSize) * product(rep) * warpsPerCTAN * warpsPerCTAM;
return rep[0] * rep[1];
}

SmallVector<int64_t>
Expand Down
19 changes: 19 additions & 0 deletions test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm | FileCheck %s

// CHECK-LABEL: wmma_dot
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func @wmma_dot(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>, %arg2: tensor<16x16xf16, #mma>) {
// CHECK-COUNT-2: llvm.extractvalue %{{.*}} : !llvm.struct<(f16)>
// CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
// CHECK: llvm.mlir.undef : vector<16xf16>
// CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xf16>
// CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (f16, f16, vector<16xf16>, i1) -> vector<16xf16>
%0 = tt.dot %arg0, %arg1, %arg2 {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<16x16xf16, #mma>
// CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16>
// CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
// CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
tt.return
}
}
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_triton_library(TritonAMDGPUToLLVM
ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
ConvertLayoutOpToLLVM.cpp
DotOpToLLVM/MFMA.cpp
DotOpToLLVM/WMMA.cpp
DotOpToLLVM.cpp
ElementwiseOpToLLVM.cpp
LoadStoreOpToLLVM.cpp
Expand Down
19 changes: 13 additions & 6 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using ::AMD::ConvertTritonGPUOpToLLVMPattern;
using ::AMD::ConvertTritonGPUOpToLLVMPatternBase;
using ::AMD::TritonGPUToLLVMTypeConverter;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
Expand All @@ -18,6 +19,10 @@ namespace AMD {
LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);

LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);
#endif
} // namespace AMD

Expand Down Expand Up @@ -45,12 +50,14 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
.getEncoding()
.dyn_cast<NvidiaMmaEncodingAttr>();
#ifdef USE_ROCM
AMDMfmaEncodingAttr mfmaLayout = D.getType()
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<AMDMfmaEncodingAttr>();
if (!isOuter && mfmaLayout && supportMFMA(op)) {
return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter);
if (!isOuter) {
auto dEncoding = D.getType().cast<RankedTensorType>().getEncoding();
if (dEncoding.isa<AMDMfmaEncodingAttr>() && supportMFMA(op)) {
return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter);
}
if (dEncoding.isa<AMDWmmaEncodingAttr>()) {
return AMD::convertWMMA(op, adaptor, getTypeConverter(), rewriter);
}
}
#endif

Expand Down
245 changes: 245 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
/*
* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

#include "../DotOpToLLVM.h"
#include "Utility.h"

#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"

using namespace mlir;
using namespace mlir::triton;

namespace AMD {
namespace {

using ::AMD::TritonGPUToLLVMTypeConverter;
using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;

enum class WMMAInstrType : uint8_t {
// D = AB + C;
// typeof(D) == typeof(C)
// typeof(A) == typeof(B)
// typeof(D), typeof(A):
FP32_FP16,
FP32_BF16,
FP16_FP16,
BF16_BF16,
INT32_IU8,
INT32_IU4,
NOT_APPLICABLE,
};

using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;

ValueTable
getValuesFromDotOperandLayoutStruct(ConversionPatternRewriter &rewriter,
TritonGPUToLLVMTypeConverter *typeConverter,
Value value, int n0, int n1, Type type,
Location loc) {
auto elems = typeConverter->unpackLLElements(loc, value, rewriter);
ValueTable vals;
for (int i = 0; i < n0; i++) {
for (int j = 0; j < n1; j++) {
vals[{i, j}] = elems[n1 * i + j];
}
}
return vals;
}

static WMMAInstrType getWMMAInstrTypeFromDot(DotOp op) {
auto aOperandTy = op.getA().getType();
auto aTensorTy = aOperandTy.cast<RankedTensorType>();
auto aElemTy = aTensorTy.getElementType();
auto bOperandTy = op.getB().getType();
auto bTensorTy = bOperandTy.cast<RankedTensorType>();
auto bElemTy = bTensorTy.getElementType();
assert(aElemTy == bElemTy);
auto cOperandTy = op.getC().getType();
auto cTensorTy = cOperandTy.cast<RankedTensorType>();
auto cElemTy = cTensorTy.getElementType();
auto dOperandTy = op.getD().getType();
auto dTensorTy = dOperandTy.cast<RankedTensorType>();
auto dElemTy = dTensorTy.getElementType();
assert(cElemTy == dElemTy);

if (dElemTy.isF32() && aElemTy.isF16())
return WMMAInstrType::FP32_FP16;
if (dElemTy.isF32() && aElemTy.isBF16())
return WMMAInstrType::FP32_BF16;
if (dElemTy.isF16() && aElemTy.isF16())
return WMMAInstrType::FP16_FP16;
if (dElemTy.isBF16() && aElemTy.isBF16())
return WMMAInstrType::BF16_BF16;
if (dElemTy.isSignedInteger(32) && aElemTy.isUnsignedInteger(8))
return WMMAInstrType::INT32_IU8;
if (dElemTy.isSignedInteger(32) && aElemTy.isUnsignedInteger(4))
return WMMAInstrType::INT32_IU4;

return WMMAInstrType::NOT_APPLICABLE;
}

Value generateWMMAOp(ConversionPatternRewriter &rewriter, Location loc,
WMMAInstrType wmmaType, Value valA, Value valB,
Value valC) {
auto resType = valC.getType();
Value falseFlag = int_val(1, false);
switch (wmmaType) {
case WMMAInstrType::FP32_FP16:
return rewriter.create<ROCDL::wmma_f32_16x16x16_f16>(
loc, TypeRange{resType}, ValueRange{valA, valB, valC});
case WMMAInstrType::FP32_BF16:
return rewriter.create<ROCDL::wmma_f32_16x16x16_bf16>(
loc, TypeRange{resType}, ValueRange{valA, valB, valC});
case WMMAInstrType::FP16_FP16:
return rewriter.create<ROCDL::wmma_f16_16x16x16_f16>(
loc, TypeRange{resType}, ValueRange{valA, valB, valC, falseFlag});
case WMMAInstrType::BF16_BF16:
return rewriter.create<ROCDL::wmma_bf16_16x16x16_bf16>(
loc, TypeRange{resType}, ValueRange{valA, valB, valC, falseFlag});
case WMMAInstrType::INT32_IU8:
return rewriter.create<ROCDL::wmma_i32_16x16x16_iu8>(
loc, TypeRange{resType}, ValueRange{valA, valB, valC, falseFlag});
case WMMAInstrType::INT32_IU4:
return rewriter.create<ROCDL::wmma_i32_16x16x16_iu4>(
loc, TypeRange{resType}, ValueRange{valA, valB, valC, falseFlag});
default:
llvm::report_fatal_error("WMMA data type not supported");
}
return Value();
}

// Conduct the Dot conversion.
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TritonGPUToLLVMTypeConverter *typeConverter) {
auto wmmaLayout = op.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<AMDWmmaEncodingAttr>();
auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr();
auto wmmaInstrType = getWMMAInstrTypeFromDot(op);

auto loc = op.getLoc();
Value a = op.getA();
Value b = op.getB();
Value d = op.getD();
auto aTensorTy = a.getType().cast<RankedTensorType>();
auto bTensorTy = b.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();
auto elemTy = aTensorTy.getElementType();

auto aEncoding = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto bEncoding = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
int kWidth = aEncoding.getKWidth();

auto repA =
wmmaLayout.getWMMARepForOperands(aTensorTy.getShape(), elemTy, kWidth, 0);
auto repB =
wmmaLayout.getWMMARepForOperands(bTensorTy.getShape(), elemTy, kWidth, 1);

assert(repA[1] == repB[0]);

Value loadedA = adaptor.getA();
Value loadedB = adaptor.getB();
Value loadedC = adaptor.getC();
auto numRepM = repA[0];
auto numRepN = repB[1];
auto numRepK = repA[1];

ValueTable ha = getValuesFromDotOperandLayoutStruct(
rewriter, typeConverter, loadedA, numRepM, numRepK,
aTensorTy.getElementType(), loc);
ValueTable hb = getValuesFromDotOperandLayoutStruct(
rewriter, typeConverter, loadedB, numRepN, numRepK,
aTensorTy.getElementType(), loc);
auto dstElemTy = dTensorTy.getElementType();
auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter);

unsigned warpSize = triton::gpu::getWarpSize(wmmaLayout);
// TODO get rid of magic numbers
unsigned vgprElemWidth = 32;
unsigned paddedOutputElemSize =
vgprElemWidth / dstElemTy.getIntOrFloatBitWidth();
// compute number of output elements that each thread holds for one WMMA
// instruction.
auto elemsPerVec = mnkDim[0] * mnkDim[1] * paddedOutputElemSize / warpSize;
auto dElemsToStorePerThread = mnkDim[0] * mnkDim[1] / warpSize;
auto vecTy = vec_ty(dstElemTy, elemsPerVec);
for (int m = 0; m < numRepM; ++m) {
for (int n = 0; n < numRepN; ++n) {
Value acc = undef(vecTy);
for (unsigned v = 0; v < dElemsToStorePerThread; ++v) {
acc = insert_element(vecTy, acc,
fc[m * numRepN * dElemsToStorePerThread +
n * dElemsToStorePerThread + v],
i32_val(v * paddedOutputElemSize));
}
for (size_t k = 0; k < numRepK; k++) {
acc = generateWMMAOp(rewriter, loc, wmmaInstrType, ha[{m, k}],
hb[{n, k}], acc);
}
for (unsigned v = 0; v < dElemsToStorePerThread; ++v) {
fc[m * numRepN * dElemsToStorePerThread + n * dElemsToStorePerThread +
v] =
extract_element(dstElemTy, acc, i32_val(v * paddedOutputElemSize));
}
}
}

// replace with new packed result
Type structTy = LLVM::LLVMStructType::getLiteral(
wmmaLayout.getContext(), SmallVector<Type>(fc.size(), dstElemTy));
Value res = typeConverter->packLLElements(loc, fc, rewriter, structTy);
rewriter.replaceOp(op, res);
return success();
}

} // namespace

LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) {
auto rankedTType = [](Value tensor) {
return tensor.getType().cast<RankedTensorType>();
};

assert(rankedTType(op.getA()).getEncoding().isa<DotOperandEncodingAttr>() &&
rankedTType(op.getB()).getEncoding().isa<DotOperandEncodingAttr>() &&
"Both $a and %b should be DotOperand layout.");

auto cTensorTy = rankedTType(op.getC());
auto dTensorTy = rankedTType(op.getD());
assert(cTensorTy.getEncoding().isa<AMDWmmaEncodingAttr>() &&
"Currently, we only support $c with a wmma layout.");

assert(cTensorTy.getShape()[0] == dTensorTy.getShape()[0] &&
cTensorTy.getShape()[1] == dTensorTy.getShape()[1] &&
"DotOp's $c operand should pass the same number of values as $d");

return convertDot(op, adaptor, rewriter, typeConverter);
}
} // namespace AMD

0 comments on commit 8d1c067

Please sign in to comment.