From 8a75feb62663c28810ca07ba7f31b48d3f77ae76 Mon Sep 17 00:00:00 2001 From: eaplatanios Date: Mon, 29 Jul 2024 18:31:49 -0700 Subject: [PATCH 1/2] Got compilation working on Windows. --- .../TritonToTritonGPU/TritonToTritonGPUPass.h | 1 + lib/Analysis/Utility.cpp | 27 ++++++++++++++++--- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 14 +++++++--- lib/Tools/LinearLayout.cpp | 5 ++++ 4 files changed, 40 insertions(+), 7 deletions(-) diff --git a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h index d3da1394e4ea..78917fdfdd7e 100644 --- a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h +++ b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h @@ -3,6 +3,7 @@ #include #include +#include namespace mlir { diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 78357818c07b..ce49463086df 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -21,8 +21,29 @@ namespace mlir { namespace { -using namespace triton; -using namespace triton::gpu; +// We only "import" the symbols that we need to avoid name conflicts. +using triton::AdvanceOp; +using triton::LinearLayout; +using triton::MakeTensorPtrOp; +using triton::gpu::getCTALayout; +using triton::gpu::getCTAsPerCGA; +using triton::gpu::getCTASplitNum; +using triton::gpu::MmaEncodingTrait; +using triton::gpu::getNumCTAs; +using triton::gpu::getOrder; +using triton::gpu::getShapePerCTA; +using triton::gpu::getThreadsPerWarp; +using triton::gpu::getThreadsPerWarpWithUniqueData; +using triton::gpu::getUniqueContigPerThread; +using triton::gpu::getWarpsPerCTA; +using triton::gpu::getWarpsPerCTAWithUniqueData; +using triton::gpu::toLinearLayout; +using triton::gpu::AMDMfmaEncodingAttr; +using triton::gpu::BlockedEncodingAttr; +using triton::gpu::DotOperandEncodingAttr; +using triton::gpu::NvidiaMmaEncodingAttr; +using triton::gpu::SliceEncodingAttr; +using triton::gpu::TritonGPUDialect; int getParentAxis(Attribute layout, int axis) { if (auto sliceEncoding = dyn_cast(layout)) { @@ -514,7 +535,7 @@ bool supportMMA(triton::DotOp op, int version) { } } if (aElemTy.isF32() && bElemTy.isF32()) { - return op.getInputPrecision() == InputPrecision::TF32 && version >= 2; + return op.getInputPrecision() == triton::InputPrecision::TF32 && version >= 2; } return supportMMA(op.getA(), version) && supportMMA(op.getB(), version); } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 2765d4ac91b9..ddb4c680b3a4 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -21,7 +21,13 @@ namespace mlir { -using namespace triton; +// We only "import" the symbols that we need to avoid name conflicts. +using triton::AxisInfo; +using triton::DialectInferLayoutInterface; +using triton::JoinOp; +using triton::ModuleAxisInfoAnalysis; +using triton::PointerType; +using triton::SplitOp; SmallVector mmaVersionToInstrShape(int version, const ArrayRef &shape, @@ -443,7 +449,7 @@ std::optional inferSrcEncoding(Operation *op, Attribute encoding) { op->hasTrait() || op->hasTrait() || isa(op)) { + triton::nvidia_gpu::WarpGroupDotWaitOp>(op)) { return encoding; } @@ -472,7 +478,7 @@ std::optional inferDstEncoding(Operation *op, Attribute encoding) { op->hasTrait() || op->hasTrait() || isa(op)) + triton::nvidia_gpu::WarpGroupDotWaitOp>(op)) return encoding; if (auto reduceOp = dyn_cast(op)) return inferDstEncoding(reduceOp, encoding); @@ -824,7 +830,7 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, } bool isPureUnaryInlineAsm(Operation *op) { - auto inlineAsmOp = dyn_cast(op); + auto inlineAsmOp = dyn_cast(op); if (!inlineAsmOp) return false; return op->getNumOperands() == 1 && op->getNumResults() == 1 && diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 7abc5fe98451..a88611340073 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -429,7 +429,12 @@ int32_t LinearLayout::getNumConsecutiveInOut() const { } } } + +#if defined(_MSC_VER) + int32_t trailingZeros = otherBits != 0 ? _tzcnt_u32(otherBits) : 31; +#else int32_t trailingZeros = otherBits != 0 ? __builtin_ctz(otherBits) : 31; +#endif return 1 << std::min(consec, trailingZeros); } From b5dd2867c11d6f81e5f0a107793b0e9e81bef55c Mon Sep 17 00:00:00 2001 From: eaplatanios Date: Wed, 31 Jul 2024 18:32:49 -0700 Subject: [PATCH 2/2] . --- lib/Analysis/Utility.cpp | 205 +++++++++---------- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 39 ++-- 2 files changed, 117 insertions(+), 127 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index ce49463086df..03297756925d 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -21,32 +21,8 @@ namespace mlir { namespace { -// We only "import" the symbols that we need to avoid name conflicts. -using triton::AdvanceOp; -using triton::LinearLayout; -using triton::MakeTensorPtrOp; -using triton::gpu::getCTALayout; -using triton::gpu::getCTAsPerCGA; -using triton::gpu::getCTASplitNum; -using triton::gpu::MmaEncodingTrait; -using triton::gpu::getNumCTAs; -using triton::gpu::getOrder; -using triton::gpu::getShapePerCTA; -using triton::gpu::getThreadsPerWarp; -using triton::gpu::getThreadsPerWarpWithUniqueData; -using triton::gpu::getUniqueContigPerThread; -using triton::gpu::getWarpsPerCTA; -using triton::gpu::getWarpsPerCTAWithUniqueData; -using triton::gpu::toLinearLayout; -using triton::gpu::AMDMfmaEncodingAttr; -using triton::gpu::BlockedEncodingAttr; -using triton::gpu::DotOperandEncodingAttr; -using triton::gpu::NvidiaMmaEncodingAttr; -using triton::gpu::SliceEncodingAttr; -using triton::gpu::TritonGPUDialect; - int getParentAxis(Attribute layout, int axis) { - if (auto sliceEncoding = dyn_cast(layout)) { + if (auto sliceEncoding = dyn_cast(layout)) { axis = axis < sliceEncoding.getDim() ? axis : axis + 1; return getParentAxis(sliceEncoding.getParent(), axis); } @@ -54,10 +30,11 @@ int getParentAxis(Attribute layout, int axis) { } SmallVector getParentOrder(Attribute layout) { - if (auto sliceEncoding = mlir::dyn_cast(layout)) { + if (auto sliceEncoding = + mlir::dyn_cast(layout)) { return getParentOrder(sliceEncoding.getParent()); } - return getOrder(layout); + return triton::gpu::getOrder(layout); } } // namespace @@ -70,7 +47,7 @@ bool ReduceOpHelper::isReductionOnLayoutFastAxis() { SmallVector ReduceOpHelper::getOrderWithAxisAtBeginning() { auto srcLayout = getSrcLayout(); - auto order = getOrder(srcLayout); + auto order = triton::gpu::getOrder(srcLayout); auto it = std::find(order.begin(), order.end(), axis); // delete the axis from order order.erase(it); @@ -90,13 +67,14 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { } unsigned threadOffset = 1; - if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + if (auto sliceLayout = + mlir::dyn_cast(srcLayout)) { auto parentLayout = sliceLayout.getParent(); - auto threadsPerWarp = getThreadsPerWarp(parentLayout); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(parentLayout); threadOffset = threadsPerWarp[sliceLayout.getDim()]; } else { - auto threadsPerWarp = getThreadsPerWarp(srcLayout); - auto order = getOrder(srcLayout); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); + auto order = triton::gpu::getOrder(srcLayout); for (unsigned i = 0; i < order.size(); i++) { if (order[i] == axis) break; @@ -112,8 +90,8 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { // TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented // in the future bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { - unsigned numCTAs = getNumCTAs(srcLayout); - assert(numCTAs == getNumCTAs(dstLayout) && + unsigned numCTAs = triton::gpu::getNumCTAs(srcLayout); + assert(numCTAs == triton::gpu::getNumCTAs(dstLayout) && "Invalid layout conversion: the numbers of CTAs of src and dst " "layouts are different"); @@ -123,17 +101,19 @@ bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { // Case where CTAsPerCGA of srcLayout in the sliced dim is not 1 is not // implemented yet - if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + if (auto sliceLayout = + mlir::dyn_cast(srcLayout)) { auto dim = sliceLayout.getDim(); - auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent()); if (CTAsPerCGA[dim] != 1) llvm::report_fatal_error("Layout conversion to be implemented"); } // Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported - if (auto sliceLayout = mlir::dyn_cast(dstLayout)) { + if (auto sliceLayout = + mlir::dyn_cast(dstLayout)) { auto dim = sliceLayout.getDim(); - auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent()); if (CTAsPerCGA[dim] != 1) return true; } @@ -142,8 +122,8 @@ bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { // srcLayout and dstLayout // Case (2): Do not use dsmem when srcCTALayout == dstCTALayout - auto srcCTALayout = getCTALayout(srcLayout); - auto dstCTALayout = getCTALayout(dstLayout); + auto srcCTALayout = triton::gpu::getCTALayout(srcLayout); + auto dstCTALayout = triton::gpu::getCTALayout(dstLayout); if (srcCTALayout == dstCTALayout) return false; @@ -155,42 +135,45 @@ unsigned ReduceOpHelper::getInterWarpSize() { auto srcReduceDimSize = static_cast(srcShape[axis]); unsigned sizeIntraWarps = getIntraWarpSize(); return std::min(srcReduceDimSize / sizeIntraWarps, - getWarpsPerCTA(getSrcLayout())[axis]); + triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]); } unsigned ReduceOpHelper::getIntraWarpSize() { auto srcReduceDimSize = static_cast(srcShape[axis]); - return std::min(srcReduceDimSize, getThreadsPerWarp(getSrcLayout())[axis]); + return std::min(srcReduceDimSize, + triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]); } unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() { auto srcReduceDimSize = static_cast(srcShape[axis]); unsigned sizeIntraWarps = getIntraWarpSizeWithUniqueData(); - return std::min( - srcReduceDimSize / sizeIntraWarps, - getWarpsPerCTAWithUniqueData(getSrcLayout(), getSrcShape())[axis]); + return std::min(srcReduceDimSize / sizeIntraWarps, + triton::gpu::getWarpsPerCTAWithUniqueData( + getSrcLayout(), getSrcShape())[axis]); } unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() { auto srcReduceDimSize = static_cast(srcShape[axis]); - unsigned elementPerThreads = - getUniqueContigPerThread(getSrcLayout(), getSrcShape())[axis]; - return std::min( - srcReduceDimSize / elementPerThreads, - getThreadsPerWarpWithUniqueData(getSrcLayout(), getSrcShape())[axis]); + unsigned elementPerThreads = triton::gpu::getUniqueContigPerThread( + getSrcLayout(), getSrcShape())[axis]; + return std::min(srcReduceDimSize / elementPerThreads, + triton::gpu::getThreadsPerWarpWithUniqueData( + getSrcLayout(), getSrcShape())[axis]); } unsigned ReduceOpHelper::getThreadsReductionAxis() { auto srcLayout = getSrcLayout(); auto srcShape = getSrcShape(); - return getThreadsPerWarpWithUniqueData(srcLayout, srcShape)[axis] * - getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis]; + return triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, + srcShape)[axis] * + triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis]; } bool ReduceOpHelper::isWarpSynchronous() { auto srcLayout = getSrcLayout(); auto srcShape = getSrcShape(); - return getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] == 1; + return triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] == + 1; } SmallVector ReduceOpHelper::getScratchConfig() { @@ -219,7 +202,7 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() { bool ReduceOpHelper::isReduceWithinCTA() { auto axis = getAxis(); auto srcLayout = getSrcLayout(); - auto CTASplitNum = getCTASplitNum(srcLayout); + auto CTASplitNum = triton::gpu::getCTASplitNum(srcLayout); assert(axis < CTASplitNum.size()); return CTASplitNum[axis] == 1; } @@ -232,13 +215,13 @@ bool ReduceOpHelper::isSupportedLayout() { } auto srcLayout = getSrcLayout(); - if (isa(srcLayout)) { + if (isa(srcLayout)) { return true; } - if (auto mmaLayout = dyn_cast(srcLayout)) { + if (auto mmaLayout = dyn_cast(srcLayout)) { return mmaLayout.supportReduction(); } - if (auto sliceLayout = dyn_cast(srcLayout)) { + if (auto sliceLayout = dyn_cast(srcLayout)) { return true; } return false; @@ -257,15 +240,16 @@ unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() { Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); } unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() { - return getThreadsPerWarp(getEncoding())[getAxis()]; + return triton::gpu::getThreadsPerWarp(getEncoding())[getAxis()]; } unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() { - return getThreadsPerWarpWithUniqueData(getEncoding(), getShape())[getAxis()]; + return triton::gpu::getThreadsPerWarpWithUniqueData(getEncoding(), + getShape())[getAxis()]; } unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() { - auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding()); threadsPerWarp[getAxis()] = 1; return product(threadsPerWarp); } @@ -273,24 +257,25 @@ unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() { // Return the flat numbers of threads computing independent scan results. unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() { unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp(); - auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding()); warpsPerCTA[getAxis()] = 1; unsigned numParallelWarpsPerCTA = product(warpsPerCTA); return numParallelThreadsPerWarp * numParallelWarpsPerCTA; } unsigned ScanLoweringHelper::getAxisNumWarps() { - return getWarpsPerCTA(getEncoding())[getAxis()]; + return triton::gpu::getWarpsPerCTA(getEncoding())[getAxis()]; } unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() { - return getWarpsPerCTAWithUniqueData(getEncoding(), getShape())[getAxis()]; + return triton::gpu::getWarpsPerCTAWithUniqueData(getEncoding(), + getShape())[getAxis()]; } unsigned ScanLoweringHelper::getAxisNumBlocks() { - auto sizePerThreads = getSizePerThread(getEncoding()); - auto threadsPerWarp = getThreadsPerWarp(getEncoding()); - auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + auto sizePerThreads = triton::gpu::getSizePerThread(getEncoding()); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding()); unsigned axis = getAxis(); return ceil( getShape()[axis], @@ -298,9 +283,9 @@ unsigned ScanLoweringHelper::getAxisNumBlocks() { } unsigned ScanLoweringHelper::getNonAxisNumBlocks() { - auto sizePerThreads = getSizePerThread(getEncoding()); - auto threadsPerWarp = getThreadsPerWarp(getEncoding()); - auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + auto sizePerThreads = triton::gpu::getSizePerThread(getEncoding()); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding()); unsigned axis = getAxis(); unsigned numBlocks = 1; for (unsigned i = 0; i < sizePerThreads.size(); i++) { @@ -316,14 +301,14 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() { bool ScanLoweringHelper::isSupported() { // TODO: Support the following cases: // 1. Scan on non-blocking encodings - if (!isa(getEncoding())) + if (!isa(getEncoding())) return false; return true; } unsigned ScanLoweringHelper::getScratchSizeInElems() { auto mod = scanOp->getParentOfType(); - unsigned numWarps = TritonGPUDialect::getNumWarps(mod); + unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); unsigned numNonAxisElementsPerWarp = getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread(); unsigned numElements = numWarps * numNonAxisElementsPerWarp * @@ -380,12 +365,12 @@ getReshapeDecomposition(ArrayRef srcShape, return ret; } -BlockedEncodingAttr ScanLoweringHelper::getEncoding() { - return cast(srcEncoding); +triton::gpu::BlockedEncodingAttr ScanLoweringHelper::getEncoding() { + return cast(srcEncoding); } unsigned ScanLoweringHelper::getAxisElementStride() { - auto order = getOrder(getEncoding()); + auto order = triton::gpu::getOrder(getEncoding()); unsigned stride = 1; for (unsigned dim : order) { if (dim == getAxis()) @@ -396,7 +381,7 @@ unsigned ScanLoweringHelper::getAxisElementStride() { } unsigned ScanLoweringHelper::getAxisThreadStride() { - auto order = getOrder(getEncoding()); + auto order = triton::gpu::getOrder(getEncoding()); unsigned stride = 1; for (unsigned dim : order) { if (dim == getAxis()) @@ -407,11 +392,11 @@ unsigned ScanLoweringHelper::getAxisThreadStride() { } unsigned ScanLoweringHelper::getAxisBlockStride() { - auto order = getOrder(getEncoding()); + auto order = triton::gpu::getOrder(getEncoding()); unsigned stride = 1; - auto sizePerThreads = getSizePerThread(getEncoding()); - auto threadsPerWarp = getThreadsPerWarp(getEncoding()); - auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + auto sizePerThreads = triton::gpu::getSizePerThread(getEncoding()); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding()); for (unsigned dim : order) { if (dim == getAxis()) return stride; @@ -428,7 +413,8 @@ bool maybeSharedAllocationOp(Operation *op) { // query the memory effects of the op. auto *dialect = op->getDialect(); return dialect && - (dialect->getTypeID() == TypeID::get() || + (dialect->getTypeID() == + TypeID::get() || dialect->getTypeID() == TypeID::get() || dialect->getTypeID() == TypeID::get() || @@ -513,10 +499,10 @@ bool supportMMA(triton::DotOp op, int version) { if (triton::tools::getBoolEnv("DISABLE_MMA_V3")) return false; auto retType = op.getType(); - auto retShapePerCTA = getShapePerCTA(retType); + auto retShapePerCTA = triton::gpu::getShapePerCTA(retType); auto rank = retShapePerCTA.size(); auto mod = op->getParentOfType(); - int numWarps = TritonGPUDialect::getNumWarps(mod); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); // TODO(Keren): for now, fallback to MMAv2 if handling batch matmul. if (rank == 3) return false; @@ -535,7 +521,8 @@ bool supportMMA(triton::DotOp op, int version) { } } if (aElemTy.isF32() && bElemTy.isF32()) { - return op.getInputPrecision() == triton::InputPrecision::TF32 && version >= 2; + return op.getInputPrecision() == triton::InputPrecision::TF32 && + version >= 2; } return supportMMA(op.getA(), version) && supportMMA(op.getB(), version); } @@ -557,8 +544,10 @@ bool supportMMA(Value value, int version) { } bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { - auto mfmaLayout = dyn_cast(srcTy.getEncoding()); - auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + auto mfmaLayout = + dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = + dyn_cast(dstTy.getEncoding()); if (mfmaLayout == nullptr || dotOperandLayout == nullptr) return false; // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is @@ -573,8 +562,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { } static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) { - auto src = dyn_cast(srcEncoding); - auto dst = dyn_cast(dstEncoding); + auto src = dyn_cast(srcEncoding); + auto dst = dyn_cast(dstEncoding); if (!src || !dst) return false; // when #mma = MmaEncoding @@ -590,8 +579,10 @@ bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { // For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy) { - auto mmaLayout = dyn_cast(srcTy.getEncoding()); - auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + auto mmaLayout = + dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = + dyn_cast(dstTy.getEncoding()); if (!mmaLayout || !dotOperandLayout) { return false; } @@ -605,13 +596,13 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { MLIRContext *ctx = srcTy.getContext(); - std::optional srcLayout = - toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - std::optional dstLayout = - toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + std::optional srcLayout = + triton::gpu::toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + std::optional dstLayout = + triton::gpu::toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); if (srcLayout.has_value() && dstLayout.has_value()) { // comp describes the layout function for converting from src to dst. - LinearLayout comp = srcLayout->invertAndCompose(*dstLayout); + triton::LinearLayout comp = srcLayout->invertAndCompose(*dstLayout); StringAttr kLane = StringAttr::get(ctx, "lane"); StringAttr kWarp = StringAttr::get(ctx, "warp"); StringAttr kBlock = StringAttr::get(ctx, "block"); @@ -621,12 +612,12 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { // // TODO(jlebar): Remove the kLane layout once we add support for // shuffle-based layout conversions in ConvertLayoutToLLVM. - if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kLane), - kLane, kLane) * - LinearLayout::identity1D(comp.getInDimSize(kWarp), - kWarp, kWarp) * - LinearLayout::identity1D(comp.getInDimSize(kBlock), - kBlock, kBlock)) + if (comp.divideRight(triton::LinearLayout::identity1D( + comp.getInDimSize(kLane), kLane, kLane) * + triton::LinearLayout::identity1D( + comp.getInDimSize(kWarp), kWarp, kWarp) * + triton::LinearLayout::identity1D( + comp.getInDimSize(kBlock), kBlock, kBlock)) .has_value()) { return false; } @@ -644,8 +635,10 @@ bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { return true; // dot_op = #mma // when #mma = MmaEncoding - auto mmaLayout = dyn_cast(srcTy.getEncoding()); - auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + auto mmaLayout = + dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = + dyn_cast(dstTy.getEncoding()); return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 && mmaLayout.getWarpsPerCTA()[1] == 1 && dotOperandLayout.getOpIdx() == 0 && @@ -872,13 +865,13 @@ std::unique_ptr createDataFlowSolver() { return solver; } -static MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { +static triton::MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { - if (auto makeTensorPtrOp = dyn_cast(op)) { + if (auto makeTensorPtrOp = dyn_cast(op)) { return makeTensorPtrOp; } - if (auto advanceOp = dyn_cast(op)) { + if (auto advanceOp = dyn_cast(op)) { return getMakeTensorPtrOp(advanceOp.getPtr()); } @@ -898,7 +891,7 @@ static MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { llvm_unreachable("Unable to getMakeTensorPtr()"); } -MakeTensorPtrOp getMakeTensorPtrOp(Value v) { +triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v) { using BranchOps = llvm::SetVector>; llvm::DenseMap blockToCFOps; auto moduleOp = diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index ddb4c680b3a4..7ebfa9337365 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -21,14 +21,6 @@ namespace mlir { -// We only "import" the symbols that we need to avoid name conflicts. -using triton::AxisInfo; -using triton::DialectInferLayoutInterface; -using triton::JoinOp; -using triton::ModuleAxisInfoAnalysis; -using triton::PointerType; -using triton::SplitOp; - SmallVector mmaVersionToInstrShape(int version, const ArrayRef &shape, RankedTensorType type, @@ -109,18 +101,19 @@ Value getMemAccessPtr(Operation *op) { unsigned getElementBitWidth(RankedTensorType type) { auto typeForMem = - isa(type.getElementType()) - ? cast(type.getElementType()).getPointeeType() + isa(type.getElementType()) + ? cast(type.getElementType()).getPointeeType() : type.getElementType(); return typeForMem.getIntOrFloatBitWidth(); } -unsigned getNumElementsPerThread(Operation *op, SmallVector order, - ModuleAxisInfoAnalysis &axisInfoAnalysis) { +unsigned +getNumElementsPerThread(Operation *op, SmallVector order, + triton::ModuleAxisInfoAnalysis &axisInfoAnalysis) { Value val = getMemAccessPtr(op); auto ty = cast(val.getType()); auto shapePerCTA = triton::gpu::getShapePerCTA(ty); - AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); + triton::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); unsigned elemNumBits = getElementBitWidth(ty); unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]); @@ -308,10 +301,11 @@ static std::optional inferDstEncoding(triton::ExpandDimsOp op, return sliceEncoding.getParent(); } -static std::optional inferDstEncoding(JoinOp op, Attribute srcEnc) { +static std::optional inferDstEncoding(triton::JoinOp op, + Attribute srcEnc) { Attribute dstEnc; if (srcEnc.getDialect() - .getRegisteredInterface() + .getRegisteredInterface() ->inferJoinOpEncoding(srcEnc, dstEnc, /*loc=*/std::nullopt) .succeeded()) { @@ -320,10 +314,11 @@ static std::optional inferDstEncoding(JoinOp op, Attribute srcEnc) { return std::nullopt; } -static std::optional inferDstEncoding(SplitOp op, Attribute srcEnc) { +static std::optional inferDstEncoding(triton::SplitOp op, + Attribute srcEnc) { Attribute dstEnc; if (srcEnc.getDialect() - .getRegisteredInterface() + .getRegisteredInterface() ->inferSplitOpEncoding(srcEnc, dstEnc, /*loc=*/std::nullopt) .succeeded()) { @@ -348,11 +343,12 @@ static std::optional inferSrcEncoding(triton::ExpandDimsOp op, encoding); } -static std::optional inferSrcEncoding(JoinOp op, Attribute dstEnc) { +static std::optional inferSrcEncoding(triton::JoinOp op, + Attribute dstEnc) { // Split is the inverse of join. Attribute srcEnc; if (dstEnc.getDialect() - .getRegisteredInterface() + .getRegisteredInterface() ->inferSplitOpEncoding(dstEnc, srcEnc, /*loc=*/std::nullopt) .succeeded()) { return srcEnc; @@ -360,11 +356,12 @@ static std::optional inferSrcEncoding(JoinOp op, Attribute dstEnc) { return std::nullopt; } -static std::optional inferSrcEncoding(SplitOp op, Attribute dstEnc) { +static std::optional inferSrcEncoding(triton::SplitOp op, + Attribute dstEnc) { // Join is the inverse of split. Attribute srcEnc; if (dstEnc.getDialect() - .getRegisteredInterface() + .getRegisteredInterface() ->inferJoinOpEncoding(dstEnc, srcEnc, /*loc=*/std::nullopt) .succeeded()) { return srcEnc;