Pad fusion bufferization workaround. (iree-org#12425)
Handling the code generated by tiling of pad operations seems to need more work in bufferization. To unblock the work on handling pad operations natively in IREE, the approach from
iree-org#11273 (comment) is implemented here as a workaround.

To ensure bufferization without allocation, the yields of the then and else branches and the result of the scf.if are all tied together. If the then and else values come from different bindings, this tying would be illegal (a copy would be needed). This case led to adding more constraints on which sets can be merged during the
BufferizationAnalysis, to avoid merging sets that contain constants or two different interface bindings.
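
For illustration, a minimal MLIR sketch of the problematic shape (modeled on the `@if_conversion` test added below; the value names and the `tensor<?xf32>` types are illustrative, not part of this change):

```mlir
// Both branches yield values loaded from two different read-only bindings.
%if = scf.if %cond -> (tensor<?xf32>) {
  scf.yield %then_value : tensor<?xf32>   // loaded from binding(0)
} else {
  scf.yield %else_value : tensor<?xf32>   // loaded from binding(1)
}
// Tying %then_value, %else_value, and %if to %result (binding(2)) would
// require all of them to share one buffer, which is illegal without a copy.
flow.dispatch.tensor.store %if, %result, offsets = [%offset], sizes = [%size], strides = [1]
    : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%dim}
```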

benchmarks: x86_64, cuda
MaheshRavishankar authored and qedawkins committed Apr 2, 2023
1 parent 64fb151 commit f9d98c7
Showing 4 changed files with 250 additions and 39 deletions.
85 changes: 61 additions & 24 deletions compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
@@ -119,6 +119,32 @@ static OpType getEquivalentOpOfType(Value value, BufferizationPlan &plan) {
return equivalentOp;
}

/// Check if two sets can be merged based on what operations exist in that set.
static bool canSetsBeMerged(Value v1, Value v2, BufferizationPlan &plan) {
// Don't merge two sets if one of the sets is a constant.
if (getEquivalentOpOfType<arith::ConstantOp>(v1, plan) ||
getEquivalentOpOfType<arith::ConstantOp>(v2, plan)) {
return false;
}
auto v1InterfaceBinding =
getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(v1, plan);
auto v2InterfaceBinding =
getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(v2, plan);
// If either of these sets does not have an interface binding, they can be merged.
if (!v1InterfaceBinding || !v2InterfaceBinding) {
return true;
}
if (v1InterfaceBinding.getSet() != v2InterfaceBinding.getSet() ||
v1InterfaceBinding.getBinding() != v2InterfaceBinding.getBinding() ||
v1InterfaceBinding.getByteOffset() !=
v2InterfaceBinding.getByteOffset()) {
// If the set, binding or offsets are different, map these to different
// memrefs.
return false;
}
return true;
}

/// Returns true if the value and target of a `flow.dispatch.tensor.store`
/// operation can be added to the same equivalence set. This can be done only if
/// - The `value` is not from an equivalence set that contains a read-only
@@ -130,32 +156,24 @@ static OpType getEquivalentOpOfType(Value value, BufferizationPlan &plan) {
/// `hal.interface.binding.subspan` op.
static bool canSetStoreValueAndTargetAsEquivalent(
IREE::Flow::DispatchTensorStoreOp storeOp, BufferizationPlan &plan) {
Value value = storeOp.getValue();
Value target = storeOp.getTarget();
auto targetInterfaceOp =
getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(target, plan);
assert(targetInterfaceOp);
if (auto valueConstantOp =
getEquivalentOpOfType<arith::ConstantOp>(value, plan)) {
if (!canSetsBeMerged(storeOp.getValue(), storeOp.getTarget(), plan)) {
return false;
}
if (auto valueInterfaceOp =
getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(value,
plan)) {
if (targetInterfaceOp.getBinding() != valueInterfaceOp.getBinding() ||
targetInterfaceOp.getByteOffset() != valueInterfaceOp.getByteOffset()) {
// If the binding and offsets are different, map these to different
// memrefs.
return false;
}
// If the binding and offsets are the same, make sure that the
// !flow.dispatch.tensor is read-write.
auto sourceType =
valueInterfaceOp.getType().dyn_cast<IREE::Flow::DispatchTensorType>();
return sourceType &&
sourceType.getAccess() == IREE::Flow::TensorAccess::ReadWrite;
auto valueInterfaceBinding =
getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(
storeOp.getValue(), plan);
auto targetInterfaceBinding =
getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(
storeOp.getTarget(), plan);
if (!valueInterfaceBinding || !targetInterfaceBinding) {
return true;
}
return true;
// If the binding and offsets are the same, make sure that the
// !flow.dispatch.tensor is read-write.
auto sourceType = valueInterfaceBinding.getType()
.dyn_cast<IREE::Flow::DispatchTensorType>();
return sourceType &&
sourceType.getAccess() == IREE::Flow::TensorAccess::ReadWrite;
}

/// Tries to add the `value` and `target` to the same equivalence class.
@@ -462,6 +480,25 @@ static void tieOperandsForOperandFusion(linalg::LinalgOp linalgOp,
}
}

void BufferizationPlan::unionSets(Value v1, Value v2) {
if (!canSetsBeMerged(v1, v2, *this)) {
return;
}
// If one of the sets was part of the store set, the store set
// needs to be updated to drop all leaders from the store set
// and add the new leader to it.
Value leader1 = getLeaderValue(v1);
Value leader2 = getLeaderValue(v2);
bool insertNewStoreLeader =
storeLeaders.count(leader1) || storeLeaders.count(leader2);
storeLeaders.erase(leader1);
storeLeaders.erase(leader2);
mappedTensors.unionSets(getPointer(v1), getPointer(v2));
if (insertNewStoreLeader) {
storeLeaders.insert(getLeaderValue(v1));
}
}

void BufferizationPlan::dump() {
llvm::dbgs() << "BufferMappings : \n";
unsigned numSets = 0;
@@ -560,7 +597,7 @@ LogicalResult createTensorEquivalenceClasses(func::FuncOp funcOp,
.Case<scf::ForOp>(
[&](scf::ForOp forOp) { return analyseScfForOp(forOp, plan); })
.Case<scf::YieldOp, tensor::EmptyOp, tensor::DimOp, tensor::ExtractOp,
tensor::PadOp, bufferization::ToMemrefOp,
tensor::GenerateOp, tensor::PadOp, bufferization::ToMemrefOp,
bufferization::AllocTensorOp>(
[&](Operation *op) { return success(); })
.Default([&](Operation *op) -> LogicalResult {
18 changes: 3 additions & 15 deletions compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.h
@@ -52,21 +52,9 @@ class BufferizationPlan {

void insert(Value v) { mappedTensors.insert(getPointer(v)); }

void unionSets(Value v1, Value v2) {
// If one the sets was part of the store set, the store set
// needs to be updated to drop the all leaders from the store set
// and add the new leader to it.
Value leader1 = getLeaderValue(v1);
Value leader2 = getLeaderValue(v2);
bool insertNewStoreLeader =
storeLeaders.count(leader1) || storeLeaders.count(leader2);
storeLeaders.erase(leader1);
storeLeaders.erase(leader2);
mappedTensors.unionSets(getPointer(v1), getPointer(v2));
if (insertNewStoreLeader) {
storeLeaders.insert(getLeaderValue(v1));
}
}
/// Union the sets containing `v1` and `v2`. Checks whether the union can be
/// done and does nothing if the union is invalid.
void unionSets(Value v1, Value v2);

/// Sets the equivalence class that contains `v` as the set that contains the
/// result tensor of the dispatch region (i.e. a tensor that is the `value`
@@ -523,6 +523,64 @@ struct RemoveCstOutsDependency
return success();
}
};

/// Add a pattern to switch
/// ```mlir
/// %0 = scf.if %cond {
/// ...
/// scf.yield %true
/// } else {
/// ...
/// scf.yield %false
/// }
/// flow.dispatch.tensor.store %0, %target, ...
/// ```
///
/// to
///
/// ```mlir
/// scf.if %cond {
/// ...
/// flow.dispatch.tensor.store %true, %target
/// } else {
/// ...
/// flow.dispatch.tensor.store %false, %target
/// }
/// ```
/// This is a workaround for #11273 while a proper fix lands.
struct SwitchStoreOfIfResultValue
: public OpRewritePattern<IREE::Flow::DispatchTensorStoreOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(IREE::Flow::DispatchTensorStoreOp storeOp,
PatternRewriter &rewriter) const override {
auto ifOp = storeOp.getValue().getDefiningOp<scf::IfOp>();
if (!ifOp) {
return rewriter.notifyMatchFailure(storeOp,
"store source is not an if statement");
}

auto resultNumber = storeOp.getValue().cast<OpResult>().getResultNumber();
auto moveStoreInsideBody = [&](Block *body) {
OpBuilder::InsertionGuard guard(rewriter);
auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
rewriter.setInsertionPoint(yieldOp);
auto yieldedVal = yieldOp.getOperand(resultNumber);
SliceAndDynamicDims sliceAndDynamicDims =
cloneOffsetsSizesAndStrides(rewriter, storeOp);
rewriter.create<IREE::Flow::DispatchTensorStoreOp>(
storeOp.getLoc(), yieldedVal, storeOp.getTarget(),
sliceAndDynamicDims.dynamicDims, sliceAndDynamicDims.offsets,
sliceAndDynamicDims.sizes, sliceAndDynamicDims.strides);
};

moveStoreInsideBody(&ifOp.getThenRegion().front());
moveStoreInsideBody(&ifOp.getElseRegion().front());
rewriter.eraseOp(storeOp);
return success();
}
};

} // namespace

void ConvertToDestinationPassingStylePass::runOnOperation() {
@@ -567,6 +625,14 @@ void ConvertToDestinationPassingStylePass::runOnOperation() {
return signalPassFailure();
}
}

{
RewritePatternSet patterns(context);
patterns.insert<SwitchStoreOfIfResultValue>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
}

std::unique_ptr<OperationPass<func::FuncOp>>
@@ -896,3 +896,123 @@ func.func @multi_result_dispatches() {
// CHECK-SAME: outs(%[[RESULT0]] :
// CHECK: flow.dispatch.tensor.store %[[MATMUL]], %[[RESULT_BINDING1]]
// CHECK: flow.dispatch.tensor.store %[[GENERIC]], %[[RESULT_BINDING0]]

// -----

func.func @if_conversion() {
%0 = hal.interface.constant.load[0] : index
%offset = hal.interface.constant.load[1] : index
%size = hal.interface.constant.load[2] : index
%cond = hal.interface.constant.load[3] : i1
%result_offset = hal.interface.constant.load[4] : index
%then = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
: !flow.dispatch.tensor<readonly:tensor<?xf32>>{%0}
%else = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
: !flow.dispatch.tensor<readonly:tensor<?xf32>>{%0}
%result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
: !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%0}
%then_value = flow.dispatch.tensor.load %then, offsets = [%offset], sizes = [%size], strides = [1]
: !flow.dispatch.tensor<readonly:tensor<?xf32>>{%0} -> tensor<?xf32>
%else_value = flow.dispatch.tensor.load %else, offsets = [%offset], sizes = [%size], strides = [1]
: !flow.dispatch.tensor<readonly:tensor<?xf32>>{%0} -> tensor<?xf32>
%if = scf.if %cond -> (tensor<?xf32>) {
scf.yield %then_value : tensor<?xf32>
} else {
scf.yield %else_value : tensor<?xf32>
}
flow.dispatch.tensor.store %if, %result, offsets = [%result_offset], sizes = [%size], strides = [1]
: tensor<?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%0}
return
}
// CHECK-LABEL: func @if_conversion()
// CHECK-DAG: %[[S0:.+]] = hal.interface.constant.load[0]
// CHECK-DAG: %[[S1:.+]] = hal.interface.constant.load[2]
// CHECK-DAG: %[[COND:.+]] = hal.interface.constant.load[3]
// CHECK-DAG: %[[OFFSET:.+]] = hal.interface.constant.load[4]
// CHECK-DAG: %[[THEN_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
// CHECK-DAG: %[[ELSE_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
// CHECK-DAG: %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
// CHECK-DAG: %[[THEN:.+]] = flow.dispatch.tensor.load %[[THEN_BINDING]]
// CHECK-DAG: %[[ELSE:.+]] = flow.dispatch.tensor.load %[[ELSE_BINDING]]
// CHECK: scf.if %[[COND]] {
// CHECK-NEXT: flow.dispatch.tensor.store %[[THEN]], %[[RESULT_BINDING]]
// CHECK-SAME: offsets = [%[[OFFSET]]], sizes = [%[[S1]]]
// CHECK-SAME: flow.dispatch.tensor<writeonly:tensor<?xf32>>{%[[S0]]}
// CHECK-NEXT: } else {
// CHECK-NEXT: flow.dispatch.tensor.store %[[ELSE]], %[[RESULT_BINDING]]
// CHECK-SAME: offsets = [%[[OFFSET]]], sizes = [%[[S1]]]
// CHECK-SAME: flow.dispatch.tensor<writeonly:tensor<?xf32>>{%[[S0]]}
// CHECK-NEXT: }
// CHECK-NEXT: return

// -----

func.func @if_conversion_clone_offsets() {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%6, %7}
%11 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%8, %9}
%12 = affine.apply affine_map<()[s0, s1] -> (-s0 + s1 + (s0 ceildiv 16) * 16)>()[%6, %6]
%13 = affine.apply affine_map<()[s0, s1] -> (-s0 + s1 + (s0 ceildiv 16) * 16)>()[%7, %7]
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%14 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
%15 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
scf.for %arg0 = %14 to %12 step %15 {
%16 = affine.min affine_map<(d0)[s0, s1] -> (64, -d0 - s0 + s1 + (s0 ceildiv 16) * 16)>(%arg0)[%6, %6]
%17 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
%18 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
scf.for %arg1 = %17 to %13 step %18 {
%19 = affine.min affine_map<(d0)[s0, s1] -> (64, -d0 - s0 + s1 + (s0 ceildiv 16) * 16)>(%arg1)[%7, %7]
%20 = affine.min affine_map<(d0)[s0] -> (s0, d0)>(%arg0)[%6]
%21 = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 + d1)>(%arg0, %16)[%6]
%22 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%21, %20)
%23 = arith.cmpi eq, %22, %c0 : index
%24 = affine.apply affine_map<(d0, d1, d2) -> (d0 - d1 + d2)>(%16, %21, %20)
%25 = affine.min affine_map<(d0)[s0] -> (s0, d0)>(%arg1)[%7]
%26 = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 + d1)>(%arg1, %19)[%7]
%27 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%26, %25)
%28 = arith.cmpi eq, %27, %c0 : index
%29 = arith.ori %28, %23 : i1
%30 = affine.apply affine_map<(d0, d1, d2) -> (d0 - d1 + d2)>(%19, %26, %25)
%31 = scf.if %29 -> (tensor<?x?xf32>) {
%generated = tensor.generate %16, %19 {
^bb0(%arg2: index, %arg3: index):
tensor.yield %cst : f32
} : tensor<?x?xf32>
scf.yield %generated : tensor<?x?xf32>
} else {
%34 = flow.dispatch.tensor.load %10, offsets = [%20, %25], sizes = [%22, %27], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%6, %7} -> tensor<?x?xf32>
%padded = tensor.pad %34 low[0, 0] high[%24, %30] {
^bb0(%arg2: index, %arg3: index):
tensor.yield %cst : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
scf.yield %padded : tensor<?x?xf32>
}
%32 = arith.index_castui %3 : i32 to index
%33 = arith.index_castui %4 : i32 to index
flow.dispatch.tensor.store %31, %11, offsets = [%arg0, %arg1], sizes = [%16, %19], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%32, %33}
}
}
return
}
// CHECK-LABEL: func @if_conversion_clone_offsets()
// CHECK: scf.if
// CHECK-NEXT: %[[GENERATED:.+]] = tensor.generate
// CHECK: flow.dispatch.tensor.store %[[GENERATED]]
// CHECK: else
// CHECK: %[[VAL:.+]] = flow.dispatch.tensor.load
// CHECK: %[[PADDED:.+]] = tensor.pad %[[VAL]]
// CHECK: flow.dispatch.tensor.store %[[PADDED]]
