[TRANSFORM] Pipeline fixes (triton-lang#20)
Jokeren authored Jul 26, 2023
1 parent 5a0a3f8 commit da11a2d
Showing 4 changed files with 91 additions and 61 deletions.
15 changes: 7 additions & 8 deletions include/triton/Analysis/Allocation.h
@@ -148,18 +148,17 @@ class Allocation {
BufferKind kind;
BufferId id;
size_t size;
size_t alignment = 4;
size_t offset = 0;
size_t alignment;
size_t offset;

bool operator==(const BufferT &other) const { return id == other.id; }
bool operator<(const BufferT &other) const { return id < other.id; }

BufferT() : BufferT(BufferKind::Explicit) {}
BufferT(BufferKind kind)
: kind(kind), id(InvalidBufferId), size(0), offset(0) {}
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
BufferT(BufferKind kind, size_t size, size_t offset)
: kind(kind), id(nextId++), size(size), offset(offset) {}
BufferT() : BufferT(BufferKind::Explicit, 0) {}
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
size_t offset = 0)
: kind(kind), id(nextId++), size(size), alignment(alignment),
offset(offset) {}
};

/// Op -> Scratch Buffer
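A minimal sketch of what the consolidated constructor means for call sites (hypothetical usage, not part of this diff; real construction happens inside the Allocation analysis, and nesting/access control is glossed over here):

// Hypothetical usage of the single BufferT constructor: alignment
// defaults to 4 bytes and offset to 0 unless a caller overrides them.
BufferT explicitBuf(BufferKind::Explicit, /*size=*/1024);
// explicitBuf.alignment == 4, explicitBuf.offset == 0
BufferT alignedBuf(BufferKind::Explicit, /*size=*/2048, /*alignment=*/16);
// alignedBuf.alignment == 16, alignedBuf.offset == 0
BufferT defaultBuf; // delegates to BufferT(BufferKind::Explicit, /*size=*/0)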
6 changes: 5 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
@@ -145,8 +145,12 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
OpBuilder builder) {
if (!layoutMap.count(ptr))
return;

// Convert operands
// For load/store with tensor pointers, we don't have to change the
// operands' types; instead, we change the result type of
// `make_tensor_ptr`.
auto convertType = layoutMap.lookup(ptr);
// convert operands
SmallVector<Value, 4> newArgs;
for (auto operand : op->getOperands()) {
auto tensorType = operand.getType().dyn_cast<RankedTensorType>();
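As a rough sketch of the idea in that comment (hypothetical code, not from this diff; the op handling and address-space plumbing are assumptions): instead of wrapping a tensor-pointer operand in a layout conversion, the pass can retype the defining make_tensor_ptr so its pointee tensor already carries the coalesced layout, then rebuild the consuming load/store to match.

// Hypothetical sketch: push the coalesced layout into the tensor pointer's
// producer instead of converting the load/store operand itself.
if (auto makePtrOp = ptr.getDefiningOp<tt::MakeTensorPtrOp>()) {
  Type coalescedTy = layoutMap.lookup(ptr); // tensor type with the new layout
  auto oldPtrTy = ptr.getType().cast<tt::PointerType>();
  makePtrOp.getResult().setType(
      tt::PointerType::get(coalescedTy, oldPtrTy.getAddressSpace()));
  // The tt.load / tt.store users are then recreated with matching types.
}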
123 changes: 75 additions & 48 deletions lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -159,6 +159,8 @@ class LoopPipeliner {
// exist in the base pipeliner
void checkOpShareBarriers(SetVector<Operation *> &ops);
int numLoadsRequireAsyncWait = 0;
int numLoadsRequireMBarrier = 0;
Value curPhase;

/// Iterator values
Value pipelineIterIdx;
@@ -307,25 +309,29 @@ LogicalResult LoopPipeliner::collectOps(SetVector<Operation *> &ops) {
// operations in the loop body block. Nested blocks are handled separately.
for (Operation &op : forOp)
if (auto loadOp = dyn_cast<tt::LoadOp>(&op)) {
auto ptr = loadOp.getPtr();
unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);

if (auto mask = loadOp.getMask())
vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));

auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy || tensorTy.getRank() < 2)
continue;
auto ty =
tensorTy.getElementType().cast<tt::PointerType>().getPointeeType();
unsigned width = vec * ty.getIntOrFloatBitWidth();
// We do not pipeline all loads for the following reasons:
// 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16.
// 2. It's likely that pipelining small loads won't offer much performance
// improvement and may even hurt performance by increasing register
// pressure.
if (width >= 32)
if (isLoadFromTensorPtr(loadOp)) {
ops.insert(loadOp);
} else {
auto ptr = loadOp.getPtr();
unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
if (auto mask = loadOp.getMask())
vec =
std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));

auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy || tensorTy.getRank() < 2)
continue;
auto ty =
tensorTy.getElementType().cast<tt::PointerType>().getPointeeType();
unsigned width = vec * ty.getIntOrFloatBitWidth();
// We do not pipeline all loads for the following reasons:
// 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16.
// 2. It's likely that pipelining small loads won't offer much performance
// improvement and may even hurt performance by increasing register
// pressure.
if (width >= 32)
ops.insert(loadOp);
}
}

if (ops.empty())
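A small worked example of the width threshold above (illustrative numbers only, not from this diff): 8 contiguous f16 elements give 8 * 16 = 128 bits, which cp.async can handle, while a non-contiguous i8 load gives 1 * 8 = 8 bits and is left unpipelined.

// Illustrative arithmetic for the `width >= 32` check; cp.async copies
// 4, 8, or 16 bytes (32/64/128 bits), so narrower accesses are skipped.
unsigned vec = 8;                 // contiguous elements from axis info
unsigned elemBits = 16;           // f16 element width
unsigned width = vec * elemBits;  // 128 bits -> pipelined via cp.async
bool pipelined = width >= 32;     // true

unsigned scalarWidth = 1 * 8;               // i8 with no contiguity
bool scalarPipelined = scalarWidth >= 32;   // false -> plain load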
@@ -453,6 +459,8 @@ LogicalResult LoopPipeliner::checkOpUses(SetVector<Operation *> &ops) {
validLoads.insert(loadOp);
if (!isLoadFromTensorPtr(loadOp))
numLoadsRequireAsyncWait++;
else
numLoadsRequireMBarrier++;
}
}
}
@@ -963,6 +971,7 @@ void LoopPipeliner::emitPrologue() {
loadsExtract[loadOp] = extractSlice;
}
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
curPhase = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 1);
}

void LoopPipeliner::emitEpilogue() {
@@ -1017,6 +1026,8 @@ SmallVector<Value> LoopPipeliner::collectNewLoopArgs() {
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
newLoopArgs.push_back(pipelineIterIdx);
newLoopArgs.push_back(loopIterIdx);
if (numLoadsRequireMBarrier > 0)
newLoopArgs.push_back(curPhase);

return newLoopArgs;
}
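To make the later ivIdx + 2 and ivIdx + 3 indexing easier to follow, the trailing layout of the new loop-carried arguments after this change is (inferred from this diff, stated here for reference only):

// Tail of newLoopArgs / region iter args, relative to ivIdx:
//   ivIdx + 0 : induction-variable value for stage numStages - 2
//   ivIdx + 1 : pipelineIterIdx
//   ivIdx + 2 : loopIterIdx
//   ivIdx + 3 : curPhase (present only when numLoadsRequireMBarrier > 0)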
@@ -1057,18 +1068,29 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef<Value> newLoopArgs,
}
} else if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
if (isLoadFromTensorPtr(loadOp)) {
auto it =
std::find(validLoads.begin(), validLoads.end(), op.getOperand(0));
// XXX(Keren): The comparison operator using std::find on tensor ptr
// doesn't work as expected
auto operand = loadOp.getPtr();
auto tensorTy =
operand.getType().cast<tt::PointerType>().getPointeeType();
auto loadArgIdx = 0;
for (auto validLoad : validLoads) {
auto defOp = cast<tt::LoadOp>(validLoad.getDefiningOp());
if (isLoadFromTensorPtr(defOp)) {
auto validOperand = defOp.getOperand(0);
auto validTensorTy =
validOperand.getType().cast<tt::PointerType>().getPointeeType();
if (tensorTy == validTensorTy)
break;
}
loadArgIdx++;
}
// We replace the new load's uses with a convert layout
auto loadArgIdx = std::distance(validLoads.begin(), it);
Value curIV = newLoopArgs[ivIdx];
Value step = newForOp.getStep();
Value upperBound = newForOp.getUpperBound();
Value oneVal =
builder.create<arith::ConstantIntOp>(loopIterIdx.getLoc(), 1,
/*bitWidth=*/32);
Value curPhase = builder.create<arith::AndIOp>(loopIterIdx.getLoc(),
loopIterIdx, oneVal);
loopIterIdx = newForOp.getRegionIterArgs()[ivIdx + 2];
curPhase = newForOp.getRegionIterArgs()[ivIdx + 3];

// consumer_release, emitted after the last consumer
// 'the last consumer' might be updated in the following Phase_1 since
@@ -1092,10 +1114,8 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef<Value> newLoopArgs,
}

// consumer_wait, emitted before the first consumer
auto firstConsumer = getFirstUser(mapping.lookup(loadOp));
mapping.lookup(loadOp).replaceAllUsesWith(
newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]);
mapping.lookup(loadOp).getDefiningOp()->erase();
auto firstConsumer = getFirstUser(loadOp);
mapping.map(loadOp, newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]);

// If current load can reuse barriers shared by previous load, then we
// do nothing.
@@ -1114,6 +1134,20 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef<Value> newLoopArgs,
cloneWithInferType(builder, &op, mapping);
}

// Remove redundant conversions
// e.g., %145 = triton_gpu.convert_layout %arg15 : (tensor<128x64xf16,
// #shared1>) -> tensor<128x64xf16, #shared1>
for (Operation &op : newForOp.getBody()->without_terminator()) {
if (auto convert_layout = dyn_cast<ttg::ConvertLayoutOp>(op)) {
auto result = op.getResult(0);
auto cvtDstTy = result.getType();
auto operand = convert_layout.getOperand();
auto tensorTy = operand.getType();
if (cvtDstTy == tensorTy)
result.replaceAllUsesWith(operand);
}
}

return newForOp;
}

@@ -1235,14 +1269,6 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask);
newMask = nextMapping.lookupOrDefault(loadOp.getMask());
} else {
if (loadOp.getPtr().getType().isa<tt::PointerType>())
newMask = nullptr;
// XXX(Keren): might be wrong for tma
// else
// newMask = builder.create<tt::SplatOp>(
// loadOp.getLoc(),
// mlir::tt::getI1SameShape(loadOp.getType()), nextLoopCond);
}
Value insertedVal;
if (mode && isLoadFromTensorPtr(loadOp)) {
@@ -1307,7 +1333,7 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
loc, loadsBuffer[loadOp].getType(),
nextMapping.lookupOrDefault(loadOp.getPtr()),
newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
insertSliceIndex, fullBarrier, nextLoopCond,
insertSliceIndex, fullBarrier, newMask,
nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(),
loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0);
} else {
@@ -1368,13 +1394,14 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
pipelineIterIdx = builder.create<arith::AddIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
// FIXME(Keren): Reenable after tma is fixed
// curWaitIdx = builder.create<arith::AddIOp>(
// forOp.getLoc(), curWaitIdx,
// builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 32));
// curPhase = builder.create<arith::XOrIOp>(
// forOp.getLoc(), curPhase,
// builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 1));
if (numLoadsRequireMBarrier > 0) {
curPhase = newForOp.getRegionIterArgs()[ivIdx + 3];
Value nextPhase = builder.create<arith::XOrIOp>(
forOp.getLoc(), curPhase,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 1));
curPhase = getBoundedIterationValue(builder, nextLoopIterIdx, numStagesVal,
curPhase, nextPhase);
}
}
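For intuition, a scalar model of the mbarrier phase bookkeeping threaded through the loop above (assumptions: the phase bit flips each time a barrier slot is reused, and getBoundedIterationValue acts as a wrap-aware select between the old and the flipped phase; neither detail is spelled out in this diff):

// Hypothetical scalar model of the curPhase update: XOR produces the
// candidate flipped phase, which is adopted only when the circular
// buffer index wraps back to slot 0.
int numStages = 3;   // illustrative stage count
int loopIterIdx = 0; // slot index within the circular buffer
int curPhase = 0;    // mbarrier phase bit carried as a loop arg

for (int iter = 0; iter < 8; ++iter) {
  int nextLoopIterIdx = (loopIterIdx + 1) % numStages;
  int nextPhase = curPhase ^ 1;
  curPhase = (nextLoopIterIdx == 0) ? nextPhase : curPhase;
  loopIterIdx = nextLoopIterIdx;
}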

void LoopPipeliner::finalizeYield(scf::ForOp newForOp, OpBuilder &builder) {
Expand All @@ -1395,8 +1422,8 @@ void LoopPipeliner::finalizeYield(scf::ForOp newForOp, OpBuilder &builder) {
yieldValues.push_back(nextIV);
yieldValues.push_back(pipelineIterIdx);
yieldValues.push_back(loopIterIdx);
// yieldValues.push_back(curWaitIdx);
// yieldValues.push_back(curPhase);
if (numLoadsRequireMBarrier > 0)
yieldValues.push_back(curPhase);

builder.setInsertionPointToEnd(newForOp.getBody());
builder.create<scf::YieldOp>(yieldOp->getLoc(), yieldValues);
8 changes: 4 additions & 4 deletions python/test/unit/language/test_core.py
@@ -2258,13 +2258,13 @@ def kernel(X, stride_xm, stride_xk,
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if in_dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif in_dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
elif in_dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
elif out_dtype == tl.float16:
assert 'mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16' in ptx
assert 'mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16' in ptx


@pytest.mark.parametrize('in_dtype', ['float32'])