[BW] Add MMAV5 op at the TTGIR level (#9)
ThomasRaoux authored Jun 5, 2024
1 parent d62f061 commit 86c85ca
Showing 11 changed files with 176 additions and 2 deletions.
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/Traits.h
@@ -71,7 +71,7 @@ class DotLike : public TraitBase<ConcreteType, DotLike> {
return op->emitOpError("expected 3 operands");
auto aTy = cast<TensorOrMemDesc>(op->getOperand(0).getType());
auto bTy = cast<TensorOrMemDesc>(op->getOperand(1).getType());
-  auto cTy = cast<TensorType>(op->getOperand(2).getType());
+  auto cTy = cast<TensorOrMemDesc>(op->getOperand(2).getType());
auto aShape = aTy.getShape();
auto bShape = bTy.getShape();
auto cShape = cTy.getShape();
12 changes: 12 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
@@ -41,4 +41,16 @@
#define GET_OP_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc"

namespace mlir {
namespace triton {
namespace nvidia_gpu {

struct TensorMemory : public SideEffects::Resource::Base<TensorMemory> {
StringRef getName() final { return "<TensorMemory>"; }
};

} // namespace nvidia_gpu
} // namespace triton
} // namespace mlir

#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_
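The TensorMemory struct above registers a new side-effect resource with MLIR, letting passes model reads and writes of Blackwell tensor memory separately from shared memory; the getEffects implementation added to Ops.cpp later in this commit attaches its effects to this resource.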
7 changes: 7 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td
@@ -26,4 +26,11 @@
include "mlir/IR/AttrTypeBase.td"
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"

def TTG_TensorMemorySpace : AttrDef<TritonNvidiaGPU_Dialect, "TensorMemorySpace"> {
let mnemonic = "tensor_memory";
let description = [{
Attribute to indicate that the memory descriptor points to tensor memory.
}];
}

#endif
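In IR, the new memory space appears as the final encoding parameter of a !tt.memdesc type. A minimal sketch, with shape, element type, and the #shared encoding name chosen for illustration (the accelerate-matmul.mlir test below shows a verified instance):

!tt.memdesc<128x256xf32, #shared, #triton_nvidia_gpu.tensor_memory, mutable>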
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td
@@ -60,6 +60,8 @@ def TritonNvidiaGPU_Dialect : Dialect {
}];

let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;
let usePropertiesForAttributes = 1;
}

include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td"
14 changes: 14 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -244,5 +244,19 @@ def TTNG_TMAStoreWait : TTNG_Op<"async_tma_store_wait"> {
let assemblyFormat = "attr-dict";
}

def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DotLike]> {
let summary = "block level op mapping to tensorcore gen5 mma";

let description = [{
$d += matrix_multiply($a, $b)
}];

let arguments = (ins TT_MemDescType:$a,
TT_MemDescType:$b,
TT_MemDescType:$d,
DefaultValuedAttr<BoolAttr, "false">:$isAsync);

let assemblyFormat = "$a`,` $b`,` $d attr-dict `:` type($a) `*` type($b) `->` type($d)";
}

#endif
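Per the assemblyFormat above, the op prints its three operands followed by type($a) * type($b) -> type($d). A hedged sketch of the textual form, with shapes and encoding names chosen for illustration (see the tests added below for verified instances):

triton_nvidia_gpu.tc_gen5_mma %a, %b, %d :
  !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> *
  !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory>
  -> !tt.memdesc<128x256xf32, #shared2, #triton_nvidia_gpu.tensor_memory, mutable>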
11 changes: 11 additions & 0 deletions lib/Analysis/Utility.cpp
@@ -538,6 +538,17 @@ bool supportMMA(triton::DotOp op, int version) {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
auto aElemTy = op.getA().getType().getElementType();
auto bElemTy = op.getB().getType().getElementType();
if (version == 5) {
auto retType = op.getType();
auto retShapePerCTA = getShapePerCTA(retType);
auto rank = retShapePerCTA.size();
auto mod = op->getParentOfType<ModuleOp>();
int numWarps = TritonGPUDialect::getNumWarps(mod);
if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
retShapePerCTA[rank - 1] % 8 == 0))
return false;
return true;
}
if (version == 3) {
if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
return false;
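As a worked example, with triton_gpu.num-warps = 4 and a 128x256 accumulator the version-5 branch accepts the dot, since 4 % 4 == 0, 128 % 64 == 0, and 256 % 8 == 0; this is the configuration exercised by the accelerate-matmul.mlir test below.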
63 changes: 62 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -27,6 +27,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
versionsSupported = {2};
} else if (computeCapability < 100) {
versionsSupported = {3, 2};
} else if (computeCapability < 110) {
versionsSupported = {5, 2};
} else {
assert(false && "computeCapability not supported");
}
@@ -335,6 +337,65 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
return success();
}
};

class BlockedToMMAv5 : public mlir::OpRewritePattern<triton::DotOp> {
int computeCapability;

public:
BlockedToMMAv5(mlir::MLIRContext *context, int computeCapability)
: mlir::OpRewritePattern<triton::DotOp>(context),
computeCapability(computeCapability) {}

mlir::LogicalResult
matchAndRewrite(triton::DotOp dotOp,
mlir::PatternRewriter &rewriter) const override {
RankedTensorType oldRetType = dotOp.getType();
if (!oldRetType.getEncoding() ||
mlir::isa<NvidiaMmaEncodingAttr>(oldRetType.getEncoding()))
return failure();

// get MMA encoding for the given number of warps
auto retShapePerCTA = getShapePerCTA(oldRetType);
auto mod = dotOp->getParentOfType<mlir::ModuleOp>();
int numWarps = TritonGPUDialect::getNumWarps(mod);
auto CTALayout = getCTALayout(oldRetType.getEncoding());

int versionMajor = getMMAVersionSafe(computeCapability, dotOp);
if (versionMajor != 5)
return failure();

auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA,
dotOp.getA().getType(), numWarps);
// operands
Value a = dotOp.getA();
Value b = dotOp.getB();
auto oldAType = dotOp.getA().getType();
auto oldBType = dotOp.getB().getType();

a = getSharedMemoryMMAOperand(a, rewriter, 0, /*allowTranspose=*/true);
b = getSharedMemoryMMAOperand(b, rewriter, 1, /*allowTranspose=*/true);
MLIRContext *context = dotOp->getContext();
auto defaultCTALayout =
triton::gpu::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1},
/*CTASplitNum=*/{1}, /*CTAOrder=*/{0});
auto accEncoding = triton::gpu::SharedEncodingAttr::get(
context, 1, 1, 1, {0}, defaultCTALayout);
Attribute tensorMemorySpace =
triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
Type accMemDescType = triton::MemDescType::get(
oldRetType.getShape(), oldRetType.getElementType(), accEncoding,
tensorMemorySpace,
/*mutableMemory=*/true);
auto acc = rewriter.create<triton::gpu::LocalAllocOp>(
dotOp.getLoc(), accMemDescType, dotOp.getOperand(2));
rewriter.create<triton::nvidia_gpu::TCGen5MMAOp>(dotOp.getLoc(), a, b, acc);

rewriter.replaceOpWithNewOp<LocalLoadOp>(dotOp, oldRetType, acc);
return success();
}
};
} // namespace

static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
@@ -395,7 +456,7 @@ class TritonGPUAccelerateMatmulPass
auto computeCapability = getNVIDIAComputeCapability(m);

mlir::RewritePatternSet patterns(context);
patterns.add<BlockedToMMA>(context, computeCapability);
patterns.add<BlockedToMMA, BlockedToMMAv5>(context, computeCapability);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
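The net effect of the new pattern: the dot's accumulator is materialized into tensor memory with triton_gpu.local_alloc, tc_gen5_mma accumulates into that buffer in place, and a triton_gpu.local_load reads the result back as a tensor, which is exactly what the accelerate-matmul.mlir test below checks.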
10 changes: 10 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -69,6 +69,16 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,

assert(false && "type not supported");
return {0, 0, 0};
} else if (version == 5) {
unsigned k = 256 / type.getElementTypeBitWidth();
if (shape[0] % 64 != 0 || shape[1] % 8 != 0) {
assert(false && "type not supported");
return {0, 0, 0};
}
// Pick the largest mma shape available.
unsigned m = shape[0] < 128 ? shape[0] : 128;
unsigned n = shape[1] < 256 ? shape[1] : 256;
return {m, n, k};
} else {
assert(false && "version not supported");
return {0, 0};
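Worked example for the version-5 branch: with f16 operands (16-bit elements) and a 128x256 accumulator shape, k = 256 / 16 = 16, m = min(128, 128) = 128, and n = min(256, 256) = 256, so the chosen instruction shape is {128, 256, 16}.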
19 changes: 19 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -179,6 +179,25 @@ void AsyncTMACopyLocalToGlobalOp::getEffects(
mlir::triton::gpu::SharedMemory::get());
}

// -- TCGen5MMAOp --
void TCGen5MMAOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Write::get(), getD(),
mlir::triton::nvidia_gpu::TensorMemory::get());
if (isa<triton::gpu::SharedMemorySpaceAttr>(
getA().getType().getMemorySpace())) {
effects.emplace_back(MemoryEffects::Read::get(), getA(),
mlir::triton::gpu::SharedMemory::get());
} else {
effects.emplace_back(MemoryEffects::Read::get(), getA(),
mlir::triton::nvidia_gpu::TensorMemory::get());
}
effects.emplace_back(MemoryEffects::Read::get(), getB(),
mlir::triton::gpu::SharedMemory::get());
}

} // namespace nvidia_gpu
} // namespace triton
} // namespace mlir
22 changes: 22 additions & 0 deletions test/TritonGPU/accelerate-matmul.mlir
@@ -142,3 +142,25 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
tt.return
}
}


// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:100", "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: mmav5
// CHECK-DAG: %[[A:.+]] = triton_gpu.local_alloc %{{.*}} : (tensor<128x64xf16, #{{.*}}>) -> !tt.memdesc<128x64xf16, #{{.*}}, #triton_gpu.shared_memory>
// CHECK-DAG: %[[B:.+]] = triton_gpu.local_alloc %{{.*}} : (tensor<64x256xf16, #{{.*}}>) -> !tt.memdesc<64x256xf16, #{{.*}}, #triton_gpu.shared_memory>
// CHECK-DAG: %[[ACC:.+]] = triton_gpu.local_alloc %{{.*}} : (tensor<128x256xf32, #{{.*}}>) -> !tt.memdesc<128x256xf32, #{{.*}}, #triton_nvidia_gpu.tensor_memory, mutable>
// CHECK: triton_nvidia_gpu.tc_gen5_mma %[[A]], %[[B]], %[[ACC]] : <128x64xf16, #{{.*}}, #triton_gpu.shared_memory> * <64x256xf16, #{{.*}}, #triton_gpu.shared_memory> -> <128x256xf32, #{{.*}}, #triton_nvidia_gpu.tensor_memory, mutable>
// CHECK: %[[R:.+]] = triton_gpu.local_load %[[ACC]] : !tt.memdesc<128x256xf32, #{{.*}}, #triton_nvidia_gpu.tensor_memory, mutable> -> tensor<128x256xf32
// CHECK: tt.return %[[R]] : tensor<128x256xf32
tt.func public @mmav5(%a: tensor<128x64xf16, #blocked2>, %b: tensor<64x256xf16, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> {
%ad = triton_gpu.convert_layout %a : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
%bd = triton_gpu.convert_layout %b : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
%d = tt.dot %ad, %bd, %c, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
tt.return %d : tensor<128x256xf32, #blocked>
}
}
16 changes: 16 additions & 0 deletions test/TritonNvidiaGPU/ops.mlir
@@ -0,0 +1,16 @@
// RUN: triton-opt --split-input-file %s | FileCheck %s

#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: @tcgen5
// CHECK: triton_nvidia_gpu.tc_gen5_mma
tt.func @tcgen5(%a: !tt.memdesc<128x128xf8E5M2, #shared, #triton_gpu.shared_memory>,
%b: !tt.memdesc<128x256xf8E5M2, #shared1, #triton_gpu.shared_memory>,
%c: !tt.memdesc<128x256xf8E5M2, #shared1, #triton_nvidia_gpu.tensor_memory>) {
triton_nvidia_gpu.tc_gen5_mma %a, %b, %c { isAsync = false }:
!tt.memdesc<128x128xf8E5M2, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x256xf8E5M2, #shared1, #triton_gpu.shared_memory>
-> !tt.memdesc<128x256xf8E5M2, #shared1, #triton_nvidia_gpu.tensor_memory>
tt.return
}
}
