From 86c85caeadfea540f2a734360f323bb4ecf75116 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Wed, 5 Jun 2024 14:11:06 -0700
Subject: [PATCH] [BW] Add MMAV5 op at the TTGIR level (#9)

---
 include/triton/Dialect/Triton/IR/Traits.h     |  2 +-
 .../Dialect/TritonNvidiaGPU/IR/Dialect.h      | 12 ++++
 .../IR/TritonNvidiaGPUAttrDefs.td             |  7 +++
 .../IR/TritonNvidiaGPUDialect.td              |  2 +
 .../TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td  | 14 +++++
 lib/Analysis/Utility.cpp                      | 11 ++++
 .../TritonGPU/Transforms/AccelerateMatmul.cpp | 63 ++++++++++++++++++-
 lib/Dialect/TritonGPU/Transforms/Utility.cpp  | 10 +++
 lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp        | 19 ++++++
 test/TritonGPU/accelerate-matmul.mlir         | 22 +++++++
 test/TritonNvidiaGPU/ops.mlir                 | 16 +++++
 11 files changed, 176 insertions(+), 2 deletions(-)
 create mode 100644 test/TritonNvidiaGPU/ops.mlir

diff --git a/include/triton/Dialect/Triton/IR/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h
index a768f279446c..6e554aed2358 100644
--- a/include/triton/Dialect/Triton/IR/Traits.h
+++ b/include/triton/Dialect/Triton/IR/Traits.h
@@ -71,7 +71,7 @@ class DotLike : public TraitBase<ConcreteType, DotLike> {
       return op->emitOpError("expected 3 operands");
     auto aTy = cast<ShapedType>(op->getOperand(0).getType());
     auto bTy = cast<ShapedType>(op->getOperand(1).getType());
-    auto cTy = cast<RankedTensorType>(op->getOperand(2).getType());
+    auto cTy = cast<ShapedType>(op->getOperand(2).getType());
     auto aShape = aTy.getShape();
     auto bShape = bTy.getShape();
     auto cShape = cTy.getShape();
diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
index 279faf9a434a..9f24b6ba7d6a 100644
--- a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
@@ -41,4 +41,16 @@
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc"
 
+namespace mlir {
+namespace triton {
+namespace nvidia_gpu {
+
+struct TensorMemory : public SideEffects::Resource::Base<TensorMemory> {
+  StringRef getName() final { return "<TensorMemory>"; }
+};
+
+} // namespace nvidia_gpu
+} // namespace triton
+} // namespace mlir
+
 #endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_
diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td
index 936535bb039a..e17d9871236c 100644
--- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td
@@ -26,4 +26,11 @@ include "mlir/IR/AttrTypeBase.td"
 include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
 include "triton/Dialect/Triton/IR/TritonInterfaces.td"
 
+def TTG_TensorMemorySpace : AttrDef<TritonNvidiaGPU_Dialect, "TensorMemorySpace"> {
+  let mnemonic = "tensor_memory";
+  let description = [{
+    Attribute to indicate that the memory descriptor points to tensor memory.
+  }];
+}
+
 #endif
diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td
index 67ece715d2f6..2ec4d61864cd 100644
--- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td
+++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td
@@ -60,6 +60,8 @@ def TritonNvidiaGPU_Dialect : Dialect {
   }];
 
   let useDefaultTypePrinterParser = 1;
+  let useDefaultAttributePrinterParser = 1;
+  let usePropertiesForAttributes = 1;
 }
 
 include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td"
diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
index 1a23c6747e9d..1c7a848ff5a5 100644
--- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
+++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -244,5 +244,19 @@ def TTNG_TMAStoreWait : TTNG_Op<"async_tma_store_wait"> {
   let assemblyFormat = "attr-dict";
 }
 
+def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DotLike]> {
+  let summary = "block level op mapping to tensorcore gen5 mma";
+
+  let description = [{
+    $d += matrix_multiply($a, $b)
+  }];
+
+  let arguments = (ins TT_MemDescType:$a,
+                       TT_MemDescType:$b,
+                       TT_MemDescType:$d,
+                       DefaultValuedAttr<BoolAttr, "false">:$isAsync);
+
+  let assemblyFormat = "$a`,` $b`,` $d attr-dict `:` type($a) `*` type($b) `->` type($d)";
+}
 
 #endif
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 32cc43c9d5d2..161c04a353fa 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -538,6 +538,17 @@ bool supportMMA(triton::DotOp op, int version) {
   // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
   auto aElemTy = op.getA().getType().getElementType();
   auto bElemTy = op.getB().getType().getElementType();
+  if (version == 5) {
+    auto retType = op.getType();
+    auto retShapePerCTA = getShapePerCTA(retType);
+    auto rank = retShapePerCTA.size();
+    auto mod = op->getParentOfType<ModuleOp>();
+    int numWarps = TritonGPUDialect::getNumWarps(mod);
+    if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
+          retShapePerCTA[rank - 1] % 8 == 0))
+      return false;
+    return true;
+  }
   if (version == 3) {
     if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
       return false;
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index ab7e96945d7c..49201a889344 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -27,6 +27,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
     versionsSupported = {2};
   } else if (computeCapability < 100) {
     versionsSupported = {3, 2};
+  } else if (computeCapability < 110) {
+    versionsSupported = {5, 2};
   } else {
     assert(false && "computeCapability not supported");
   }
@@ -335,6 +337,65 @@ class BlockedToMMA : public mlir::OpRewritePattern<triton::DotOp> {
     return success();
   }
 };
+
+class BlockedToMMAv5 : public mlir::OpRewritePattern<triton::DotOp> {
+  int computeCapability;
+  mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding
+  mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
+
+public:
+  BlockedToMMAv5(mlir::MLIRContext *context, int computeCapability)
+      : mlir::OpRewritePattern<triton::DotOp>(context),
+        computeCapability(computeCapability) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(triton::DotOp dotOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    RankedTensorType oldRetType = dotOp.getType();
+    if (!oldRetType.getEncoding() ||
+        mlir::isa<NvidiaMmaEncodingAttr>(oldRetType.getEncoding()))
+      return failure();
+
+    // get MMA encoding for the given number of warps
+    auto retShapePerCTA = getShapePerCTA(oldRetType);
+    auto mod = dotOp->getParentOfType<mlir::ModuleOp>();
+    int numWarps = TritonGPUDialect::getNumWarps(mod);
+    auto CTALayout = getCTALayout(oldRetType.getEncoding());
+
+    int versionMajor = getMMAVersionSafe(computeCapability, dotOp);
+    if (versionMajor != 5)
+      return failure();
+
+    auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA,
+                                             dotOp.getA().getType(), numWarps);
+    // operands
+    Value a = dotOp.getA();
+    Value b = dotOp.getB();
+    auto oldAType = dotOp.getA().getType();
+    auto oldBType = dotOp.getB().getType();
+
+    a = getSharedMemoryMMAOperand(a, rewriter, 0, /*allowTranspose=*/true);
+    b = getSharedMemoryMMAOperand(b, rewriter, 1, /*allowTranspose=*/true);
+    MLIRContext *context = dotOp->getContext();
+    auto defaultCTALayout =
+        triton::gpu::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1},
+                                        /*CTASplitNum=*/{1}, /*CTAOrder=*/{0});
+    auto accEncoding = triton::gpu::SharedEncodingAttr::get(
+        context, 1, 1, 1, {0}, defaultCTALayout);
+    Attribute tensorMemorySpace =
+        triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
+    Type accMemDescType = triton::MemDescType::get(
+        oldRetType.getShape(), oldRetType.getElementType(), accEncoding,
+        tensorMemorySpace,
+        /*mutableMemory=*/true);
+    auto acc = rewriter.create<triton::gpu::LocalAllocOp>(
+        dotOp.getLoc(), accMemDescType, dotOp.getOperand(2));
+    rewriter.create<triton::nvidia_gpu::TCGen5MMAOp>(dotOp.getLoc(), a, b, acc);
+
+    rewriter.replaceOpWithNewOp<triton::gpu::LocalLoadOp>(dotOp, oldRetType, acc);
+    return success();
+  }
+};
 } // namespace
 
 static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
@@ -395,7 +456,7 @@ class TritonGPUAccelerateMatmulPass
     auto computeCapability = getNVIDIAComputeCapability(m);
 
     mlir::RewritePatternSet patterns(context);
-    patterns.add<BlockedToMMA>(context, computeCapability);
+    patterns.add<BlockedToMMA, BlockedToMMAv5>(context, computeCapability);
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
     }
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index 57c535959ae2..08433d177613 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -69,6 +69,16 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
       assert(false && "type not supported");
       return {0, 0, 0};
+  } else if (version == 5) {
+    unsigned k = 256 / type.getElementTypeBitWidth();
+    if (shape[0] % 64 != 0 || shape[1] % 8 != 0) {
+      assert(false && "type not supported");
+      return {0, 0, 0};
+    }
+    // Pick the largest mma shape available.
+    unsigned m = shape[0] < 128 ? shape[0] : 128;
+    unsigned n = shape[1] < 256 ? shape[1] : 256;
+    return {m, n, k};
   } else {
     assert(false && "version not supported");
     return {0, 0};
diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
index 36489b2d2bf3..72b49761a89f 100644
--- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -179,6 +179,25 @@ void AsyncTMACopyLocalToGlobalOp::getEffects(
                        mlir::triton::gpu::SharedMemory::get());
 }
 
+// -- TCGen5MMAOp --
+void TCGen5MMAOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Write::get(), getD(),
+                       mlir::triton::nvidia_gpu::TensorMemory::get());
+  if (isa<triton::gpu::SharedMemorySpaceAttr>(
+          getA().getType().getMemorySpace())) {
+    effects.emplace_back(MemoryEffects::Read::get(), getA(),
+                         mlir::triton::gpu::SharedMemory::get());
+
+  } else {
+    effects.emplace_back(MemoryEffects::Read::get(), getA(),
+                         mlir::triton::nvidia_gpu::TensorMemory::get());
+  }
+  effects.emplace_back(MemoryEffects::Read::get(), getB(),
+                       mlir::triton::gpu::SharedMemory::get());
+}
+
 } // namespace nvidia_gpu
 } // namespace triton
 } // namespace mlir
diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir
index d37c9ebd901a..155320b89a54 100644
--- a/test/TritonGPU/accelerate-matmul.mlir
+++ b/test/TritonGPU/accelerate-matmul.mlir
@@ -142,3 +142,25 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
     tt.return
   }
 }
+
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:100", "triton_gpu.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: mmav5
+  // CHECK-DAG: %[[A:.+]] = triton_gpu.local_alloc %{{.*}} : (tensor<128x64xf16, #{{.*}}>) -> !tt.memdesc<128x64xf16, #{{.*}}, #triton_gpu.shared_memory>
+  // CHECK-DAG: %[[B:.+]] = triton_gpu.local_alloc %{{.*}} : (tensor<64x256xf16, #{{.*}}>) -> !tt.memdesc<64x256xf16, #{{.*}}, #triton_gpu.shared_memory>
+  // CHECK-DAG: %[[ACC:.+]] = triton_gpu.local_alloc %{{.*}} : (tensor<128x256xf32, #{{.*}}>) -> !tt.memdesc<128x256xf32, #{{.*}}, #triton_nvidia_gpu.tensor_memory, mutable>
+  // CHECK: triton_nvidia_gpu.tc_gen5_mma %[[A]], %[[B]], %[[ACC]] : <128x64xf16, #{{.*}}, #triton_gpu.shared_memory> * <64x256xf16, #{{.*}}, #triton_gpu.shared_memory> -> <128x256xf32, #{{.*}}, #triton_nvidia_gpu.tensor_memory, mutable>
+  // CHECK: %[[R:.+]] = triton_gpu.local_load %[[ACC]] : !tt.memdesc<128x256xf32, #{{.*}}, #triton_nvidia_gpu.tensor_memory, mutable> -> tensor<128x256xf32
+  // CHECK: tt.return %[[R]] : tensor<128x256xf32
+  tt.func public @mmav5(%a: tensor<128x64xf16, #blocked2>, %b: tensor<64x256xf16, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> {
+    %ad = triton_gpu.convert_layout %a : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+    %bd = triton_gpu.convert_layout %b : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    %d = tt.dot %ad, %bd, %c, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
+    tt.return %d : tensor<128x256xf32, #blocked>
+  }
+}
diff --git a/test/TritonNvidiaGPU/ops.mlir b/test/TritonNvidiaGPU/ops.mlir
new file mode 100644
index 000000000000..f01f5de87e3a
--- /dev/null
+++ b/test/TritonNvidiaGPU/ops.mlir
@@ -0,0 +1,16 @@
+// RUN: triton-opt --split-input-file %s | FileCheck %s
+
+#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: @tcgen5
+  // CHECK: triton_nvidia_gpu.tc_gen5_mma
+  tt.func @tcgen5(%a: !tt.memdesc<128x128xf8E5M2, #shared, #triton_gpu.shared_memory>,
+                  %b: !tt.memdesc<128x256xf8E5M2, #shared1, #triton_gpu.shared_memory>,
+                  %c: !tt.memdesc<128x256xf8E5M2, #shared1, #triton_nvidia_gpu.tensor_memory>) {
+    triton_nvidia_gpu.tc_gen5_mma %a, %b, %c { isAsync = false }:
+       !tt.memdesc<128x128xf8E5M2, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x256xf8E5M2, #shared1, #triton_gpu.shared_memory>
+       -> !tt.memdesc<128x256xf8E5M2, #shared1, #triton_nvidia_gpu.tensor_memory>
+    tt.return
+  }
+}
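
Note (not part of the diff above): the accelerate-matmul.mlir test shows the shape of the rewrite. On an sm_100 target, BlockedToMMAv5 stages A and B in shared memory, allocates the accumulator as a mutable memdesc in the new tensor_memory space, issues the MMA into it, and loads the result back into registers. A minimal sketch of the resulting TTGIR, with illustrative value names and encoding aliases (%a, %b, %c and #sharedA/#sharedB/#sharedAcc are placeholders, not names produced by the pass):

  // Sketch only: value names and #sharedA/#sharedB/#sharedAcc encodings are placeholders.
  %a_sh = triton_gpu.local_alloc %a : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #sharedA, #triton_gpu.shared_memory>
  %b_sh = triton_gpu.local_alloc %b : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #sharedB, #triton_gpu.shared_memory>
  // The accumulator is initialized from %c and lives in tensor memory as a mutable memdesc.
  %acc = triton_gpu.local_alloc %c : (tensor<128x256xf32, #blocked>) -> !tt.memdesc<128x256xf32, #sharedAcc, #triton_nvidia_gpu.tensor_memory, mutable>
  triton_nvidia_gpu.tc_gen5_mma %a_sh, %b_sh, %acc :
     !tt.memdesc<128x64xf16, #sharedA, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #sharedB, #triton_gpu.shared_memory>
     -> !tt.memdesc<128x256xf32, #sharedAcc, #triton_nvidia_gpu.tensor_memory, mutable>
  // The dot result is read back from tensor memory into the original register layout.
  %res = triton_gpu.local_load %acc : !tt.memdesc<128x256xf32, #sharedAcc, #triton_nvidia_gpu.tensor_memory, mutable> -> tensor<128x256xf32, #blocked>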
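
Also for reference, the version == 5 branch added to mmaVersionToInstrShape picks m = min(M, 128), n = min(N, 256), and k = 256 / elementBitWidth, rejecting shapes where M is not a multiple of 64 or N is not a multiple of 8. For the f16 case in the test (16-bit elements, 128x256 accumulator) that is k = 256 / 16 = 16, giving an instruction shape of {128, 256, 16}; a 64x512xf16 accumulator would clamp to {64, 256, 16}. supportMMA enforces the same M/N multiples and additionally requires the module's num-warps to be a multiple of 4 before version 5 is considered.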