Pad fusion bufferization workaround.
It seems like handling the code generated by the tiling of pad
operations needs more work in bufferization. To unblock the work of
handling pad operations natively in IREE,
iree-org#11273 (comment)
is implemented here as a workaround.

To ensure bufferization without allocation, the yields of the `then` and
`else` branches and the result of the `scf.if` are all tied together. If the
`then` and `else` values come from different bindings, this tying would be
illegal (because a copy would be needed). This example led to adding more
constraints on which sets can be merged during `BufferizationAnalysis`, to
avoid merging sets that contain constants or refer to two different
`interface_bindings`.
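
For illustration, a condensed sketch of the problematic pattern, adapted from
the `@if_conversion` test added in this commit (offsets, sizes, and strides
elided; the SSA names are illustrative):

```mlir
// Each branch yields a tensor loaded from a different binding; tying both
// yields and %if to the result binding would require a copy, so those
// equivalence sets must not be merged.
%then_value = flow.dispatch.tensor.load %then, ...    // binding(0)
%else_value = flow.dispatch.tensor.load %else, ...    // binding(1)
%if = scf.if %cond -> (tensor<?xf32>) {
  scf.yield %then_value : tensor<?xf32>
} else {
  scf.yield %else_value : tensor<?xf32>
}
flow.dispatch.tensor.store %if, %result, ...          // binding(2)
```

The `SwitchStoreOfIfResultValue` pattern added below sidesteps this by moving
the store into each branch of the `scf.if`, so each branch writes directly to
the result binding.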
Mahesh Ravishankar committed Feb 28, 2023
1 parent 17eafc9 commit b34880c
Showing 4 changed files with 249 additions and 38 deletions.
83 changes: 60 additions & 23 deletions compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
@@ -119,6 +119,32 @@ static OpType getEquivalentOpOfType(Value value, BufferizationPlan &plan) {
return equivalentOp;
}

/// Check if two sets can be merged based on what operations exist in those sets.
static bool canSetsBeMerged(Value v1, Value v2, BufferizationPlan &plan) {
// Don't merge two sets if one of them contains a constant.
if (getEquivalentOpOfType<arith::ConstantOp>(v1, plan) ||
getEquivalentOpOfType<arith::ConstantOp>(v2, plan)) {
return false;
}
auto v1InterfaceBinding =
getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(v1, plan);
auto v2InterfaceBinding =
getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(v2, plan);
// If either set does not have an interface binding, they can be merged.
if (!v1InterfaceBinding || !v2InterfaceBinding) {
return true;
}
if (v1InterfaceBinding.getSet() != v2InterfaceBinding.getSet() ||
v1InterfaceBinding.getBinding() != v2InterfaceBinding.getBinding() ||
v1InterfaceBinding.getByteOffset() !=
v2InterfaceBinding.getByteOffset()) {
// If the set, binding or offsets are different, map these to different
// memrefs.
return false;
}
return true;
}

/// Returns true if the value and target of a `flow.dispatch.tensor.store`
/// operation can be added to the same equivalence set. This can be done only if
/// - The `value` is not from an equivalence set that contains a read-only
@@ -130,32 +156,24 @@ static OpType getEquivalentOpOfType(Value value, BufferizationPlan &plan) {
/// `hal.interface.binding.subspan` op.
static bool canSetStoreValueAndTargetAsEquivalent(
IREE::Flow::DispatchTensorStoreOp storeOp, BufferizationPlan &plan) {
Value value = storeOp.getValue();
Value target = storeOp.getTarget();
auto targetInterfaceOp =
getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(target, plan);
assert(targetInterfaceOp);
if (auto valueConstantOp =
getEquivalentOpOfType<arith::ConstantOp>(value, plan)) {
if (!canSetsBeMerged(storeOp.getValue(), storeOp.getTarget(), plan)) {
return false;
}
if (auto valueInterfaceOp =
getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(value,
plan)) {
if (targetInterfaceOp.getBinding() != valueInterfaceOp.getBinding() ||
targetInterfaceOp.getByteOffset() != valueInterfaceOp.getByteOffset()) {
// If the binding and offsets are different, map these to different
// memrefs.
return false;
}
// If the binding and offsets are the same, make sure that the
// !flow.dispatch.tensor is read-write.
auto sourceType =
valueInterfaceOp.getType().dyn_cast<IREE::Flow::DispatchTensorType>();
return sourceType &&
sourceType.getAccess() == IREE::Flow::TensorAccess::ReadWrite;
auto valueInterfaceBinding =
getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(
storeOp.getValue(), plan);
auto targetInterfaceBinding =
getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(
storeOp.getTarget(), plan);
if (!valueInterfaceBinding || !targetInterfaceBinding) {
return true;
}
return true;
// If the binding and offsets are the same, make sure that the
// !flow.dispatch.tensor is read-write.
auto sourceType = valueInterfaceBinding.getType()
.dyn_cast<IREE::Flow::DispatchTensorType>();
return sourceType &&
sourceType.getAccess() == IREE::Flow::TensorAccess::ReadWrite;
}

/// Tries to add the `value` and `target` to the same equivalence class.
@@ -462,6 +480,25 @@ static void tieOperandsForOperandFusion(linalg::LinalgOp linalgOp,
}
}

void BufferizationPlan::unionSets(Value v1, Value v2) {
if (!canSetsBeMerged(v1, v2, *this)) {
return;
}
// If one of the sets was part of the store set, the store set needs to be
// updated to drop all leaders from the store set and add the new leader to
// it.
Value leader1 = getLeaderValue(v1);
Value leader2 = getLeaderValue(v2);
bool insertNewStoreLeader =
storeLeaders.count(leader1) || storeLeaders.count(leader2);
storeLeaders.erase(leader1);
storeLeaders.erase(leader2);
mappedTensors.unionSets(getPointer(v1), getPointer(v2));
if (insertNewStoreLeader) {
storeLeaders.insert(getLeaderValue(v1));
}
}

void BufferizationPlan::dump() {
llvm::dbgs() << "BufferMappings : \n";
unsigned numSets = 0;
18 changes: 3 additions & 15 deletions compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.h
@@ -52,21 +52,9 @@ class BufferizationPlan {

void insert(Value v) { mappedTensors.insert(getPointer(v)); }

void unionSets(Value v1, Value v2) {
// If one the sets was part of the store set, the store set
// needs to be updated to drop the all leaders from the store set
// and add the new leader to it.
Value leader1 = getLeaderValue(v1);
Value leader2 = getLeaderValue(v2);
bool insertNewStoreLeader =
storeLeaders.count(leader1) || storeLeaders.count(leader2);
storeLeaders.erase(leader1);
storeLeaders.erase(leader2);
mappedTensors.unionSets(getPointer(v1), getPointer(v2));
if (insertNewStoreLeader) {
storeLeaders.insert(getLeaderValue(v1));
}
}
/// Union the sets containing `v1` and `v2`. Checks whether the union is
/// legal and does nothing if it is not.
void unionSets(Value v1, Value v2);

/// Sets the equivalence class that contains `v` as the set that contains the
/// result tensor of the dispatch region (i.e. a tensor that is the `value`
@@ -523,6 +523,64 @@ struct RemoveCstOutsDependency
return success();
}
};

/// Add a pattern to switch
/// ```mlir
/// %0 = scf.if %cond {
/// ...
/// scf.yield %true
/// } else {
/// ...
/// scf.yield %false
/// }
/// flow.dispatch.tensor.store %0, %target, ...
/// ```
///
/// to
///
/// ```mlir
/// scf.if %cond {
/// ...
/// flow.dispatch.tensor.store %true, %target
/// } else {
/// ...
/// flow.dispatch.tensor.store %false, %target
/// }
/// ```
/// This is a workaround for #11273 while a proper fix lands.
struct SwitchStoreOfIfResultValue
: public OpRewritePattern<IREE::Flow::DispatchTensorStoreOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(IREE::Flow::DispatchTensorStoreOp storeOp,
PatternRewriter &rewriter) const override {
auto ifOp = storeOp.getValue().getDefiningOp<scf::IfOp>();
if (!ifOp) {
return rewriter.notifyMatchFailure(storeOp,
"store source is not an if statement");
}

auto resultNumber = storeOp.getValue().cast<OpResult>().getResultNumber();
auto moveStoreInsideBody = [&](Block *body) {
OpBuilder::InsertionGuard g2(rewriter);
auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
rewriter.setInsertionPoint(yieldOp);
auto yieldedVal = yieldOp.getOperand(resultNumber);
SliceAndDynamicDims sliceAndDynamicDims =
cloneOffsetsSizesAndStrides(rewriter, storeOp);
rewriter.create<IREE::Flow::DispatchTensorStoreOp>(
storeOp.getLoc(), yieldedVal, storeOp.getTarget(),
sliceAndDynamicDims.dynamicDims, sliceAndDynamicDims.offsets,
sliceAndDynamicDims.sizes, sliceAndDynamicDims.strides);
};

moveStoreInsideBody(&ifOp.getThenRegion().front());
moveStoreInsideBody(&ifOp.getElseRegion().front());
rewriter.eraseOp(storeOp);
return success();
}
};

} // namespace

void ConvertToDestinationPassingStylePass::runOnOperation() {
@@ -567,6 +625,14 @@ void ConvertToDestinationPassingStylePass::runOnOperation() {
return signalPassFailure();
}
}

{
RewritePatternSet patterns(context);
patterns.insert<SwitchStoreOfIfResultValue>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
}

std::unique_ptr<OperationPass<func::FuncOp>>
@@ -896,3 +896,123 @@ func.func @multi_result_dispatches() {
// CHECK-SAME: outs(%[[RESULT0]] :
// CHECK: flow.dispatch.tensor.store %[[MATMUL]], %[[RESULT_BINDING1]]
// CHECK: flow.dispatch.tensor.store %[[GENERIC]], %[[RESULT_BINDING0]]

// -----

func.func @if_conversion() {
%0 = hal.interface.constant.load[0] : index
%offset = hal.interface.constant.load[1] : index
%size = hal.interface.constant.load[2] : index
%cond = hal.interface.constant.load[3] : i1
%result_offset = hal.interface.constant.load[4] : index
%then = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
: !flow.dispatch.tensor<readonly:tensor<?xf32>>{%0}
%else = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
: !flow.dispatch.tensor<readonly:tensor<?xf32>>{%0}
%result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
: !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%0}
%then_value = flow.dispatch.tensor.load %then, offsets = [%offset], sizes = [%size], strides = [1]
: !flow.dispatch.tensor<readonly:tensor<?xf32>>{%0} -> tensor<?xf32>
%else_value = flow.dispatch.tensor.load %else, offsets = [%offset], sizes = [%size], strides = [1]
: !flow.dispatch.tensor<readonly:tensor<?xf32>>{%0} -> tensor<?xf32>
%if = scf.if %cond -> (tensor<?xf32>) {
scf.yield %then_value : tensor<?xf32>
} else {
scf.yield %else_value : tensor<?xf32>
}
flow.dispatch.tensor.store %if, %result, offsets = [%result_offset], sizes = [%size], strides = [1]
: tensor<?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%0}
return
}
// CHECK-LABEL: func @if_conversion()
// CHECK-DAG: %[[S0:.+]] = hal.interface.constant.load[0]
// CHECK-DAG: %[[S1:.+]] = hal.interface.constant.load[2]
// CHECK-DAG: %[[COND:.+]] = hal.interface.constant.load[3]
// CHECK-DAG: %[[OFFSET:.+]] = hal.interface.constant.load[4]
// CHECK-DAG: %[[THEN_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
// CHECK-DAG: %[[ELSE_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
// CHECK-DAG: %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
// CHECK-DAG: %[[THEN:.+]] = flow.dispatch.tensor.load %[[THEN_BINDING]]
// CHECK-DAG: %[[ELSE:.+]] = flow.dispatch.tensor.load %[[ELSE_BINDING]]
// CHECK: scf.if %[[COND]] {
// CHECK-NEXT: flow.dispatch.tensor.store %[[THEN]], %[[RESULT_BINDING]]
// CHECK-SAME: offsets = [%[[OFFSET]]], sizes = [%[[S1]]]
// CHECK-SAME: flow.dispatch.tensor<writeonly:tensor<?xf32>>{%[[S0]]}
// CHECK-NEXT: } else {
// CHECK-NEXT: flow.dispatch.tensor.store %[[ELSE]], %[[RESULT_BINDING]]
// CHECK-SAME: offsets = [%[[OFFSET]]], sizes = [%[[S1]]]
// CHECK-SAME: flow.dispatch.tensor<writeonly:tensor<?xf32>>{%[[S0]]}
// CHECK-NEXT: }
// CHECK-NEXT: return

// -----

func.func @if_conversion_clone_offsets() {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%6, %7}
%11 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%8, %9}
%12 = affine.apply affine_map<()[s0, s1] -> (-s0 + s1 + (s0 ceildiv 16) * 16)>()[%6, %6]
%13 = affine.apply affine_map<()[s0, s1] -> (-s0 + s1 + (s0 ceildiv 16) * 16)>()[%7, %7]
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%14 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
%15 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
scf.for %arg0 = %14 to %12 step %15 {
%16 = affine.min affine_map<(d0)[s0, s1] -> (64, -d0 - s0 + s1 + (s0 ceildiv 16) * 16)>(%arg0)[%6, %6]
%17 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
%18 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
scf.for %arg1 = %17 to %13 step %18 {
%19 = affine.min affine_map<(d0)[s0, s1] -> (64, -d0 - s0 + s1 + (s0 ceildiv 16) * 16)>(%arg1)[%7, %7]
%20 = affine.min affine_map<(d0)[s0] -> (s0, d0)>(%arg0)[%6]
%21 = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 + d1)>(%arg0, %16)[%6]
%22 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%21, %20)
%23 = arith.cmpi eq, %22, %c0 : index
%24 = affine.apply affine_map<(d0, d1, d2) -> (d0 - d1 + d2)>(%16, %21, %20)
%25 = affine.min affine_map<(d0)[s0] -> (s0, d0)>(%arg1)[%7]
%26 = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 + d1)>(%arg1, %19)[%7]
%27 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%26, %25)
%28 = arith.cmpi eq, %27, %c0 : index
%29 = arith.ori %28, %23 : i1
%30 = affine.apply affine_map<(d0, d1, d2) -> (d0 - d1 + d2)>(%19, %26, %25)
%31 = scf.if %29 -> (tensor<?x?xf32>) {
%generated = tensor.generate %16, %19 {
^bb0(%arg2: index, %arg3: index):
tensor.yield %cst : f32
} : tensor<?x?xf32>
scf.yield %generated : tensor<?x?xf32>
} else {
%34 = flow.dispatch.tensor.load %10, offsets = [%20, %25], sizes = [%22, %27], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%6, %7} -> tensor<?x?xf32>
%padded = tensor.pad %34 low[0, 0] high[%24, %30] {
^bb0(%arg2: index, %arg3: index):
tensor.yield %cst : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
scf.yield %padded : tensor<?x?xf32>
}
%32 = arith.index_castui %3 : i32 to index
%33 = arith.index_castui %4 : i32 to index
flow.dispatch.tensor.store %31, %11, offsets = [%arg0, %arg1], sizes = [%16, %19], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%32, %33}
}
}
return
}
// CHECK-LABEL: func @if_conversion_clone_offsets()
// CHECK: scf.if
// CHECK-NEXT: %[[GENERATED:.+]] = tensor.generate
// CHECK: flow.dispatch.tensor.store %[[GENERATED]]
// CHECK: else
// CHECK: %[[VAL:.+]] = flow.dispatch.tensor.load
// CHECK: %[[PADDED:.+]] = tensor.pad %[[VAL]]
// CHECK: flow.dispatch.tensor.store %[[PADDED]]
