diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h
index a1d103b86810..6370eba55b32 100644
--- a/include/triton/Analysis/Allocation.h
+++ b/include/triton/Analysis/Allocation.h
@@ -148,18 +148,17 @@ class Allocation {
     BufferKind kind;
     BufferId id;
     size_t size;
-    size_t alignment = 4;
-    size_t offset = 0;
+    size_t alignment;
+    size_t offset;

     bool operator==(const BufferT &other) const { return id == other.id; }
     bool operator<(const BufferT &other) const { return id < other.id; }

-    BufferT() : BufferT(BufferKind::Explicit) {}
-    BufferT(BufferKind kind)
-        : kind(kind), id(InvalidBufferId), size(0), offset(0) {}
-    BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
-    BufferT(BufferKind kind, size_t size, size_t offset)
-        : kind(kind), id(nextId++), size(size), offset(offset) {}
+    BufferT() : BufferT(BufferKind::Explicit, 0) {}
+    BufferT(BufferKind kind, size_t size, size_t alignment = 4,
+            size_t offset = 0)
+        : kind(kind), id(nextId++), size(size), alignment(alignment),
+          offset(offset) {}
   };

   /// Op -> Scratch Buffer
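The `BufferT` change above collapses the three constructors into one and turns `alignment` into a constructor parameter that defaults to 4 bytes, so a caller can request a larger alignment when it needs one. A minimal standalone sketch of the same pattern; the `alignUp` helper and the concrete numbers are illustrative only and not part of the patch:

#include <cstddef>
#include <cstdio>

// Simplified stand-in for Allocation::BufferT: the field names mirror the
// patch, everything else here is made up for illustration.
struct BufferT {
  size_t size;
  size_t alignment;
  size_t offset;

  explicit BufferT(size_t size, size_t alignment = 4, size_t offset = 0)
      : size(size), alignment(alignment), offset(offset) {}
};

// Round an offset up to the next multiple of `alignment` (alignment > 0).
static size_t alignUp(size_t offset, size_t alignment) {
  return (offset + alignment - 1) / alignment * alignment;
}

int main() {
  BufferT a(/*size=*/100);                    // alignment defaults to 4
  BufferT b(/*size=*/64, /*alignment=*/128);  // e.g. a buffer needing 128B
  b.offset = alignUp(a.offset + a.size, b.alignment);
  std::printf("b.offset = %zu\n", b.offset);  // prints 128
  return 0;
}

Keeping the alignment on the buffer itself lets an offset-assignment pass round each buffer's start address up independently instead of assuming one global alignment.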
diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
index e4bb62901706..b7cba6cbf5dc 100644
--- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
@@ -145,8 +145,12 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
                   OpBuilder builder) {
     if (!layoutMap.count(ptr))
       return;
+
+    // Convert operands.
+    // For load/store with tensor pointers, we don't have to change the
+    // operands' types; instead we change the result type of
+    // `make_tensor_ptr`.
     auto convertType = layoutMap.lookup(ptr);
-    // convert operands
     SmallVector<Value, 4> newArgs;
     for (auto operand : op->getOperands()) {
       auto tensorType = operand.getType().dyn_cast<RankedTensorType>();
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index 27589cb91f2d..c60603fc3ddb 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -159,6 +159,8 @@ class LoopPipeliner {
   // exist in the base pipeliner
   void checkOpShareBarriers(SetVector<Operation *> &ops);
   int numLoadsRequireAsyncWait = 0;
+  int numLoadsRequireMBarrier = 0;
+  Value curPhase;

   /// Iterator values
   Value pipelineIterIdx;
@@ -307,25 +309,29 @@ LogicalResult LoopPipeliner::collectOps(SetVector<Operation *> &ops) {
   // operations in the loop body block. Nested blocks are handled separately.
   for (Operation &op : forOp)
     if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
-      auto ptr = loadOp.getPtr();
-      unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
-
-      if (auto mask = loadOp.getMask())
-        vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask));
-
-      auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
-      if (!tensorTy || tensorTy.getRank() < 2)
-        continue;
-      auto ty =
-          tensorTy.getElementType().cast<triton::PointerType>().getPointeeType();
-      unsigned width = vec * ty.getIntOrFloatBitWidth();
-      // We do not pipeline all loads for the following reasons:
-      // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16.
-      // 2. It's likely that pipling small loads won't offer much performance
-      //    improvement and may even hurt performance by increasing register
-      //    pressure.
-      if (width >= 32)
+      if (isLoadFromTensorPtr(loadOp)) {
         ops.insert(loadOp);
+      } else {
+        auto ptr = loadOp.getPtr();
+        unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
+        if (auto mask = loadOp.getMask())
+          vec =
+              std::min(vec, axisInfoAnalysis.getMaskAlignment(mask));
+
+        auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
+        if (!tensorTy || tensorTy.getRank() < 2)
+          continue;
+        auto ty =
+            tensorTy.getElementType().cast<triton::PointerType>().getPointeeType();
+        unsigned width = vec * ty.getIntOrFloatBitWidth();
+        // We do not pipeline all loads for the following reasons:
+        // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16.
+        // 2. It's likely that pipelining small loads won't offer much
+        //    performance improvement and may even hurt performance by
+        //    increasing register pressure.
+        if (width >= 32)
+          ops.insert(loadOp);
+      }
     }

   if (ops.empty())
@@ -453,6 +459,8 @@ LogicalResult LoopPipeliner::checkOpUses(SetVector<Operation *> &ops) {
         validLoads.insert(loadOp);
         if (!isLoadFromTensorPtr(loadOp))
           numLoadsRequireAsyncWait++;
+        else
+          numLoadsRequireMBarrier++;
       }
     }
   }
@@ -963,6 +971,7 @@ void LoopPipeliner::emitPrologue() {
     loadsExtract[loadOp] = extractSlice;
   }
   loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
+  curPhase = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 1);
 }

 void LoopPipeliner::emitEpilogue() {
@@ -1017,6 +1026,8 @@ SmallVector<Value> LoopPipeliner::collectNewLoopArgs() {
   newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
   newLoopArgs.push_back(pipelineIterIdx);
   newLoopArgs.push_back(loopIterIdx);
+  if (numLoadsRequireMBarrier > 0)
+    newLoopArgs.push_back(curPhase);

   return newLoopArgs;
 }
@@ -1057,18 +1068,29 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef<Value> newLoopArgs,
       }
     } else if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
       if (isLoadFromTensorPtr(loadOp)) {
-        auto it =
-            std::find(validLoads.begin(), validLoads.end(), op.getOperand(0));
+        // XXX(Keren): The comparison operator using std::find on tensor ptr
+        // doesn't work as expected
+        auto operand = loadOp.getPtr();
+        auto tensorTy =
+            operand.getType().cast<triton::PointerType>().getPointeeType();
+        auto loadArgIdx = 0;
+        for (auto validLoad : validLoads) {
+          auto defOp = cast<triton::LoadOp>(validLoad.getDefiningOp());
+          if (isLoadFromTensorPtr(defOp)) {
+            auto validOperand = defOp.getOperand(0);
+            auto validTensorTy =
+                validOperand.getType().cast<triton::PointerType>().getPointeeType();
+            if (tensorTy == validTensorTy)
+              break;
+          }
+          loadArgIdx++;
+        }
         // We replace the use new load use with a convert layout
-        auto loadArgIdx = std::distance(validLoads.begin(), it);
         Value curIV = newLoopArgs[ivIdx];
         Value step = newForOp.getStep();
         Value upperBound = newForOp.getUpperBound();
-        Value oneVal =
-            builder.create<arith::ConstantIntOp>(loopIterIdx.getLoc(), 1,
-                                                 /*bitWidth=*/32);
-        Value curPhase = builder.create<arith::AndIOp>(loopIterIdx.getLoc(),
-                                                       loopIterIdx, oneVal);
+        loopIterIdx = newForOp.getRegionIterArgs()[ivIdx + 2];
+        curPhase = newForOp.getRegionIterArgs()[ivIdx + 3];

         // consumer_relase, emitted after the last consumer
         // 'the last consumer' might be updated in the following Phase_1 since
@@ -1092,10 +1114,8 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef<Value> newLoopArgs,
         }

         // consumer_wait, emitted before the first consumer
-        auto firstConsumer = getFirstUser(mapping.lookup(loadOp));
-        mapping.lookup(loadOp).replaceAllUsesWith(
-            newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]);
-        mapping.lookup(loadOp).getDefiningOp()->erase();
+        auto firstConsumer = getFirstUser(loadOp);
+        mapping.map(loadOp, newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]);

         // If current load can reuse barriers shared by previous load, then we
         // do nothing.
@@ -1114,6 +1134,20 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef<Value> newLoopArgs,
       cloneWithInferType(builder, &op, mapping);
   }

+  // Remove redundant conversions
+  // e.g., %145 = triton_gpu.convert_layout %arg15 : (tensor<128x64xf16,
+  // #shared1>) -> tensor<128x64xf16, #shared1>
+  for (Operation &op : newForOp.getBody()->without_terminator()) {
+    if (auto convert_layout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
+      auto result = op.getResult(0);
+      auto cvtDstTy = result.getType();
+      auto operand = convert_layout.getOperand();
+      auto tensorTy = operand.getType();
+      if (cvtDstTy == tensorTy)
+        result.replaceAllUsesWith(operand);
+    }
+  }
+
   return newForOp;
 }

@@ -1235,14 +1269,6 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
       if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
         nextMapping.map(mask, newMask);
       newMask = nextMapping.lookupOrDefault(loadOp.getMask());
-    } else {
-      if (loadOp.getPtr().getType().isa<triton::PointerType>())
-        newMask = nullptr;
-      // XXX(Keren): might be wrong for tma
-      // else
-      //   newMask = builder.create<triton::SplatOp>(
-      //       loadOp.getLoc(),
-      //       mlir::tt::getI1SameShape(loadOp.getType()), nextLoopCond);
     }
     Value insertedVal;
     if (mode && isLoadFromTensorPtr(loadOp)) {
@@ -1307,7 +1333,7 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
           loc, loadsBuffer[loadOp].getType(),
           nextMapping.lookupOrDefault(loadOp.getPtr()),
           newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
-          insertSliceIndex, fullBarrier, nextLoopCond,
+          insertSliceIndex, fullBarrier, newMask,
           nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(),
           loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0);
     } else {
@@ -1368,13 +1394,14 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
   pipelineIterIdx = builder.create<arith::AddIOp>(
       nextIV.getLoc(), pipelineIterIdx,
       builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
-  // FIXME(Keren): Reenable after tma is fixed
-  // curWaitIdx = builder.create<arith::AddIOp>(
-  //     forOp.getLoc(), curWaitIdx,
-  //     builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 32));
-  // curPhase = builder.create<arith::XOrIOp>(
-  //     forOp.getLoc(), curPhase,
-  //     builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 1));
+  if (numLoadsRequireMBarrier > 0) {
+    curPhase = newForOp.getRegionIterArgs()[ivIdx + 3];
+    Value nextPhase = builder.create<arith::XOrIOp>(
+        forOp.getLoc(), curPhase,
+        builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 1));
+    curPhase = getBoundedIterationValue(builder, nextLoopIterIdx, numStagesVal,
+                                        curPhase, nextPhase);
+  }
 }

 void LoopPipeliner::finalizeYield(scf::ForOp newForOp, OpBuilder &builder) {
@@ -1395,8 +1422,8 @@ void LoopPipeliner::finalizeYield(scf::ForOp newForOp, OpBuilder &builder) {
   yieldValues.push_back(nextIV);
   yieldValues.push_back(pipelineIterIdx);
   yieldValues.push_back(loopIterIdx);
-  // yieldValues.push_back(curWaitIdx);
-  // yieldValues.push_back(curPhase);
+  if (numLoadsRequireMBarrier > 0)
+    yieldValues.push_back(curPhase);

   builder.setInsertionPointToEnd(newForOp.getBody());
   builder.create<scf::YieldOp>(yieldOp->getLoc(), yieldValues);
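The Pipeline.cpp changes above thread an mbarrier phase bit through the pipelined loop: curPhase starts at 0 in the prologue, rides along as an extra loop-carried value (region iter-arg ivIdx + 3), is xor'ed with 1 to form a candidate next phase, and, assuming getBoundedIterationValue picks the flipped value only when the stage index wraps back to the first barrier slot, is yielded back for the next iteration. A small scalar sketch of that bookkeeping; numStages, mbarrierWait, and the trip count below are made-up stand-ins for the MLIR values and ops in the patch:

#include <cstdio>

// Illustrative only: models the loop-carried bookkeeping the patch adds
// (loopIterIdx and curPhase), not the real mbarrier instructions.
static void mbarrierWait(int stage, int phase) {
  std::printf("wait on barrier %d with phase %d\n", stage, phase);
}

int main() {
  const int numStages = 3;  // assumed pipeline depth
  int loopIterIdx = 0;      // which barrier/buffer slot this iteration uses
  int curPhase = 0;         // phase bit, starts at 0 as in emitPrologue

  for (int i = 0; i < 8; ++i) {           // arbitrary trip count
    mbarrierWait(loopIterIdx, curPhase);  // consumer waits on this slot
    int nextIterIdx = (loopIterIdx + 1) % numStages;
    int nextPhase = curPhase ^ 1;
    // Keep the old phase until the slot index wraps around; a barrier's
    // parity only changes once we come back to a slot we already used.
    curPhase = (nextIterIdx == 0) ? nextPhase : curPhase;
    loopIterIdx = nextIterIdx;
  }
  return 0;
}

With three stages the wait sequence is (0,0), (1,0), (2,0), (0,1), (1,1), (2,1), (0,0), (1,0), which is the alternation the extra yielded curPhase value carries across iterations.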
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 233bb85dfadf..fb572d369207 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -2258,13 +2258,13 @@ def kernel(X, stride_xm, stride_xk,
     assert 'ld.global.v4' in ptx
     assert 'st.global.v4' in ptx
     if in_dtype == 'float32' and allow_tf32:
-            assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
+        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
     elif in_dtype == 'float32' and allow_tf32:
-            assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
+        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
     elif in_dtype == 'int8':
-            assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
+        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
     elif out_dtype == tl.float16:
-            assert 'mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16' in ptx
+        assert 'mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16' in ptx


 @pytest.mark.parametrize('in_dtype', ['float32'])