Skip to content

Commit

Permalink
[Do not submit] Use pad hack.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mahesh Ravishankar committed Apr 18, 2023
1 parent fde414d commit 8e6435e
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,9 @@ static SmallVector<Operation *> getAllFusableProducers(TilingInterface op) {
Operation *currOp = worklist.front();
worklist.pop_front();
for (OpOperand &operand : currOp->getOpOperands()) {
auto tilingInterfaceProducer =
operand.get().getDefiningOp<TilingInterface>();
if (!tilingInterfaceProducer ||
Operation *definingOp = operand.get().getDefiningOp();
auto tilingInterfaceProducer = dyn_cast<TilingInterface>(definingOp);
if (!tilingInterfaceProducer || isa<tensor::PadOp>(definingOp) ||
producers.count(tilingInterfaceProducer)) {
continue;
}
Expand Down
27 changes: 27 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ static void collectTiledAndFusedOps(Operation *rootOp,
}
}

/// WAR: tiling a `tensor.pad` op can wrap the tiled pad in an `scf.if` (the
/// branch guards against degenerate/zero-sized slices). Fold that `scf.if`
/// away by hoisting the block containing the tiled pad out in front of the
/// `scf.if`, forwarding the values its terminator yielded to the `scf.if`
/// users, and erasing the now-dangling terminator.
/// Returns the tiled pad op on success; returns failure (and leaves the IR
/// unmodified) if the tiled pad is not directly nested in an `scf.if`.
static FailureOr<tensor::PadOp> foldIfGeneratedFromPadding(
    RewriterBase &rewriter, tensor::PadOp untiledPadOp,
    tensor::PadOp tiledPadOp) {
  // Only applies when the tiled pad's immediate parent is an `scf.if`.
  auto ifOp = dyn_cast<scf::IfOp>(tiledPadOp->getParentOp());
  if (!ifOp) {
    return failure();
  }
  // NOTE(review): currently unused; presumably kept so the callee mirrors
  // the caller's (untiled, tiled) pairing — confirm before removing.
  (void)untiledPadOp;
  // Capture the yielded values before the block is moved; `inlineBlockBefore`
  // moves (does not clone) the ops, so these Values stay valid.
  Block *block = tiledPadOp->getBlock();
  Operation *terminator = block->getTerminator();
  ValueRange results = terminator->getOperands();
  rewriter.inlineBlockBefore(block, ifOp, /*blockArgs=*/{});
  // Replace the `scf.if` results with the hoisted yielded values, then drop
  // the orphaned terminator (its parent `scf.if` is gone).
  rewriter.replaceOp(ifOp, results);
  rewriter.eraseOp(terminator);
  return tiledPadOp;
}

/// This pass starts with the last TilingInterface operation, tiles the op and
/// fuses its producers recursively. The `tilingLevel` must be specified. It
/// picks the `tilingLevel`-th list as tiling sizes from lowering_config.
Expand Down Expand Up @@ -83,6 +99,17 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp,
}
yieldedValuesToOrigValues.append(rootOp->result_begin(),
rootOp->result_end());

// WAR for tiling of `tensor.pad` ops generating `scf.if` operations.
if (auto rootPadOp = dyn_cast<tensor::PadOp>(rootOp)) {
assert(tilingResult->tiledOps.size() == 1 &&
"expected tiling of `pad` op to return only one operation");
FailureOr<Operation *> replacementTiledOp = foldIfGeneratedFromPadding(
rewriter, rootPadOp, cast<tensor::PadOp>(tilingResult->tiledOps[0]));
if (!failed(replacementTiledOp)) {
tilingResult->tiledOps[0] = replacementTiledOp.value();
}
}
tiledOps.append(tilingResult->tiledOps);

// 2. Tiling each operation results in generation of slices. The source of
Expand Down
17 changes: 13 additions & 4 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ static void addTileAndDistributePasses(
createFoldAffineMinInDistributedLoopsPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
if (clEnablePadConsumerFusion && useFuseTensorPadWithConsumerPass) {
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
}
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());
nestedModulePM.addNestedPass<func::FuncOp>(
IREE::LinalgExt::createTileAndDecomposeAttentionPass());
nestedModulePM.addNestedPass<func::FuncOp>(
Expand Down Expand Up @@ -440,6 +440,10 @@ void addMultiTilingExpertPassPipeline(OpPassManager &passManager,

for (int64_t i = 1; i < numLevels - 1; ++i) {
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(i));
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());
}
// Run SplitReductionPass before the final reduction Fuse pass, because
// SplitReductionPass takes care of banked-tiling.
Expand Down Expand Up @@ -501,6 +505,11 @@ void addConvTileAndDecomposeExpertPassPipeline(OpPassManager &passManager,

nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
static_cast<int64_t>(TilingLevel::ParallelTiles)));
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());

nestedModulePM.addNestedPass<func::FuncOp>(
createLLVMCPUTilePass(static_cast<int64_t>(TilingLevel::ReductionTiles)));
nestedModulePM.addNestedPass<func::FuncOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,23 @@ hal.executable private @pad_only {
// CHECK: %[[OUTPUT:.+]] = hal.interface.binding.subspan {{.+}} : memref<1x114x114x64xf32
// CHECK: scf.for
// CHECK: scf.for
// CHECK: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT]]
// CHECK: scf.for
// CHECK: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT]]
// CHECK: scf.for
// CHECK: scf.if
// CHECK: %[[OUTPUT_SUBVIEW_IF:.+]] = memref.subview %[[OUTPUT]]
// CHECK: linalg.generic
// CHECK-SAME: outs(%[[OUTPUT_SUBVIEW_IF]]
// CHECK: else
// CHECK: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT]]
// CHECK: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT]]
// CHECK: scf.for
// CHECK: scf.for
// CHECK: %[[OUTPUT_SLICE:.+]] = memref.subview %[[OUTPUT_SUBVIEW]]
// CHECK: %[[RESULT_VEC:.+]] = scf.if %{{.+}} -> (vector<4xf32>) {
// CHECK: %[[VEC_LOAD:.+]] = vector.load %[[INPUT_SUBVIEW]]
// CHECK: scf.yield %[[VEC_LOAD]]
// CHECK: }
// CHECK: vector.store %[[RESULT_VEC]], %[[OUTPUT_SLICE]]
// CHECK: scf.for
// CHECK: %[[OUTPUT_SLICE:.+]] = memref.subview %[[OUTPUT_SUBVIEW]]
// CHECK: %[[RESULT_VEC:.+]] = scf.if %{{.+}} -> (vector<4xf32>) {
// CHECK: %[[VEC_LOAD:.+]] = vector.load %[[INPUT_SUBVIEW]]
// CHECK: scf.yield %[[VEC_LOAD]]
// CHECK: }
// CHECK: vector.store %[[RESULT_VEC]], %[[OUTPUT_SLICE]]

// -----

Expand Down Expand Up @@ -117,36 +122,38 @@ hal.executable private @pad_with_producer {
// CHECK: %[[OUTPUT:.+]] = hal.interface.binding.subspan {{.+}} : memref<1x30x30x128xf32
// CHECK: scf.for
// CHECK: scf.for
// CHECK-DAG: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT]]
// CHECK-DAG: %[[FILTER_SUBVIEW:.+]] = memref.subview %[[FILTER]]
// CHECK-DAG: %[[BIAS_SUBVIEW:.+]] = memref.subview %[[BIAS]]
// CHECK-DAG: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT]]
// CHECK: scf.for
// CHECK: scf.if
// CHECK: else
// CHECK-DAG: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT]]
// CHECK-DAG: %[[FILTER_SUBVIEW:.+]] = memref.subview %[[FILTER]]
// CHECK-DAG: %[[BIAS_SUBVIEW:.+]] = memref.subview %[[BIAS]]
// CHECK-DAG: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT]]
// CHECK: scf.for
// CHECK-DAG: %[[INPUT_SLICE:.+]] = memref.subview %[[INPUT_SUBVIEW]]
// CHECK-DAG: %[[BIAS_ALLOC:.+]] = memref.alloca
// CHECK: scf.for
// CHECK: %[[FILTER_SLICE:.+]] = memref.subview %[[FILTER_SUBVIEW]]
// CHECK: %[[FILL_ALLOC:.+]] = memref.alloca
// CHECK: linalg.fill
// CHECK-SAME: outs(%[[FILL_ALLOC]]
// CHECK: %[[CONV_OUTPUT:.+]] = memref.subview %[[FILL_ALLOC]]
// CHECK: scf.for
// CHECK: %[[CONV_INPUT:.+]] = memref.subview %[[INPUT_SLICE]]
// CHECK: %[[CONV_FILTER:.+]] = memref.subview %[[FILTER_SLICE]]
// CHECK: linalg.conv_2d_nhwc_hwcf
// CHECK-SAME: ins(%[[CONV_INPUT]], %[[CONV_FILTER]] :
// CHECK-SAME: outs(%[[CONV_OUTPUT]] :
// CHECK: %[[BIAS_INPUT:.+]] = memref.subview %[[BIAS_SUBVIEW]]
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[CONV_OUTPUT]], %[[BIAS_INPUT]] :
// CHECK-SAME: outs(%[[BIAS_ALLOC]]
// CHECK: %[[OUTPUT_SLICE:.+]] = memref.subview %[[OUTPUT_SUBVIEW]]
// CHECK: linalg.fill ins(%{{.+}} : f32) outs(%[[OUTPUT_SLICE]]
// CHECK: %[[INTERIOR_SLICE:.+]] = memref.subview %[[OUTPUT_SLICE]]
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[BIAS_ALLOC]] :
// CHECK-SAME: outs(%[[INTERIOR_SLICE]] :
// CHECK-DAG: %[[INPUT_SLICE:.+]] = memref.subview %[[INPUT_SUBVIEW]]
// CHECK-DAG: %[[BIAS_ALLOC:.+]] = memref.alloca
// CHECK: scf.for
// CHECK: %[[FILTER_SLICE:.+]] = memref.subview %[[FILTER_SUBVIEW]]
// CHECK: %[[FILL_ALLOC:.+]] = memref.alloca
// CHECK: linalg.fill
// CHECK-SAME: outs(%[[FILL_ALLOC]]
// CHECK: %[[CONV_OUTPUT:.+]] = memref.subview %[[FILL_ALLOC]]
// CHECK: scf.for
// CHECK: %[[CONV_INPUT:.+]] = memref.subview %[[INPUT_SLICE]]
// CHECK: %[[CONV_FILTER:.+]] = memref.subview %[[FILTER_SLICE]]
// CHECK: linalg.conv_2d_nhwc_hwcf
// CHECK-SAME: ins(%[[CONV_INPUT]], %[[CONV_FILTER]] :
// CHECK-SAME: outs(%[[CONV_OUTPUT]] :
// CHECK: %[[BIAS_INPUT:.+]] = memref.subview %[[BIAS_SUBVIEW]]
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[CONV_OUTPUT]], %[[BIAS_INPUT]] :
// CHECK-SAME: outs(%[[BIAS_ALLOC]]
// CHECK: %[[OUTPUT_SLICE:.+]] = memref.subview %[[OUTPUT_SUBVIEW]]
// CHECK: linalg.fill ins(%{{.+}} : f32) outs(%[[OUTPUT_SLICE]]
// CHECK: %[[INTERIOR_SLICE:.+]] = memref.subview %[[OUTPUT_SLICE]]
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[BIAS_ALLOC]] :
// CHECK-SAME: outs(%[[INTERIOR_SLICE]] :

// -----

Expand Down Expand Up @@ -195,7 +202,6 @@ hal.executable private @pad_consumer_fusion {
// CHECK: %[[INPUT:.+]] = hal.interface.binding.subspan {{.+}} : memref<1x14x14x256xf32>
// CHECK: %[[FILTER:.+]] = hal.interface.binding.subspan {{.+}} : memref<3x3x256x256xf32>
// CHECK: %[[OUTPUT:.+]] = hal.interface.binding.subspan {{.+}} : memref<1x14x14x256xf32>
// CHECK-DAG: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT]]
// CHECK-DAG: %[[FILTER_SUBVIEW:.+]] = memref.subview %[[FILTER]]
// CHECK-DAG: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT]]
// CHECK: scf.for
Expand All @@ -205,7 +211,7 @@ hal.executable private @pad_consumer_fusion {
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK-COUNT-7: vector.load %[[INPUT_SUBVIEW]]
// CHECK-COUNT-7: vector.load %[[INPUT]]
// CHECK-COUNT-8: vector.load %[[FILTER_SUBVIEW]]
// CHECK-COUNT-8: vector.outerproduct
// CHECK: scf.yield
Expand Down
13 changes: 1 addition & 12 deletions compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,19 +496,8 @@ std::optional<LoopTilingAndDistributionInfo> isTiledAndDistributedLoop(
}

SmallVector<Operation *> getComputeOps(func::FuncOp funcOp) {
Block *body = &funcOp.getFunctionBody().front();
auto forOps = body->getOps<scf::ForOp>();
while (!forOps.empty()) {
assert(llvm::hasSingleElement(forOps) &&
"expected dispatch function with single block");
scf::ForOp forOp = *(forOps.begin());
body = forOp.getBody();
forOps = body->getOps<scf::ForOp>();
}
SmallVector<Operation *> computeOps;
for (auto op : body->getOps<TilingInterface>()) {
computeOps.push_back(op);
}
funcOp.walk([&](TilingInterface op) { computeOps.push_back(op); });
return computeOps;
}

Expand Down

0 comments on commit 8e6435e

Please sign in to comment.