diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
index d834530c31ac2..1caf8e2e6fa43 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
@@ -436,9 +436,9 @@ static SmallVector<Operation *> getAllFusableProducers(TilingInterface op) {
     Operation *currOp = worklist.front();
     worklist.pop_front();
     for (OpOperand &operand : currOp->getOpOperands()) {
-      auto tilingInterfaceProducer =
-          operand.get().getDefiningOp<TilingInterface>();
-      if (!tilingInterfaceProducer ||
+      Operation *definingOp = operand.get().getDefiningOp();
+      auto tilingInterfaceProducer = dyn_cast_or_null<TilingInterface>(definingOp);
+      if (!tilingInterfaceProducer || isa<tensor::PadOp>(definingOp) ||
           producers.count(tilingInterfaceProducer)) {
         continue;
       }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
index 787fde77188a3..bf1a79d9d873f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
@@ -45,6 +45,22 @@ static void collectTiledAndFusedOps(Operation *rootOp,
   }
 }
 
+static FailureOr<tensor::PadOp> foldIfGeneratedFromPadding(
+    RewriterBase &rewriter, tensor::PadOp untiledPadOp,
+    tensor::PadOp tiledPadOp) {
+  auto ifOp = dyn_cast<scf::IfOp>(tiledPadOp->getParentOp());
+  if (!ifOp) {
+    return failure();
+  }
+  Block *block = tiledPadOp->getBlock();
+  Operation *terminator = block->getTerminator();
+  ValueRange results = terminator->getOperands();
+  rewriter.inlineBlockBefore(block, ifOp, /*blockArgs=*/{});
+  rewriter.replaceOp(ifOp, results);
+  rewriter.eraseOp(terminator);
+  return tiledPadOp;
+}
+
 /// This pass starts with the last TilingInterface operation, tiles the op and
 /// fuses its producers recursively. The `tilingLevel` must be specified. It
 /// picks the `tilingLevel`-th list as tiling sizes from lowering_config.
@@ -83,6 +99,17 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp,
   }
   yieldedValuesToOrigValues.append(rootOp->result_begin(),
                                    rootOp->result_end());
+
+  // WAR for `pad` ops generating `scf.if` operations.
+  if (auto rootPadOp = dyn_cast<tensor::PadOp>(rootOp)) {
+    assert(tilingResult->tiledOps.size() == 1 &&
+           "expected tiling of `pad` op to return only one operation");
+    FailureOr<tensor::PadOp> replacementTiledOp = foldIfGeneratedFromPadding(
+        rewriter, rootPadOp, cast<tensor::PadOp>(tilingResult->tiledOps[0]));
+    if (!failed(replacementTiledOp)) {
+      tilingResult->tiledOps[0] = replacementTiledOp.value();
+    }
+  }
   tiledOps.append(tilingResult->tiledOps);
 
   // 2. Tiling each operation results in generation of slices.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 719b31becb536..6c3b986cbfd1a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -134,10 +134,10 @@ static void addTileAndDistributePasses(
       createFoldAffineMinInDistributedLoopsPass());
   nestedModulePM.addPass(createCanonicalizerPass());
   nestedModulePM.addPass(createCSEPass());
-  if (clEnablePadConsumerFusion && useFuseTensorPadWithConsumerPass) {
-    nestedModulePM.addNestedPass<func::FuncOp>(
-        createFuseTensorPadWithConsumerPass());
-  }
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      createFuseTensorPadWithConsumerPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      createConcretizePadResultShapePass());
   nestedModulePM.addNestedPass<func::FuncOp>(
       IREE::LinalgExt::createTileAndDecomposeAttentionPass());
   nestedModulePM.addNestedPass<func::FuncOp>(
@@ -440,6 +440,10 @@ void addMultiTilingExpertPassPipeline(OpPassManager &passManager,
   for (int64_t i = 1; i < numLevels - 1; ++i) {
     nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(i));
+    nestedModulePM.addNestedPass<func::FuncOp>(
+        createFuseTensorPadWithConsumerPass());
+    nestedModulePM.addNestedPass<func::FuncOp>(
+        createConcretizePadResultShapePass());
   }
   // Run SplitReductionPass before the final reduction Fuse pass, because
   // SplitReductionPass takes care of banked-tiling.
@@ -501,6 +505,11 @@ void addConvTileAndDecomposeExpertPassPipeline(OpPassManager &passManager,
   nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
      static_cast<int64_t>(TilingLevel::ParallelTiles)));
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      createFuseTensorPadWithConsumerPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      createConcretizePadResultShapePass());
+
   nestedModulePM.addNestedPass<func::FuncOp>(
       createLLVMCPUTilePass(static_cast<int64_t>(TilingLevel::ReductionTiles)));
   nestedModulePM.addNestedPass<func::FuncOp>(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pad_pipeline_tests.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pad_pipeline_tests.mlir
index 806cf7b3af470..3a90248e3b762 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pad_pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pad_pipeline_tests.mlir
@@ -39,18 +39,23 @@ hal.executable private @pad_only {
 // CHECK: %[[OUTPUT:.+]] = hal.interface.binding.subspan {{.+}} : memref<1x114x114x64xf32
 // CHECK: scf.for
 // CHECK: scf.for
-// CHECK: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT]]
 // CHECK: scf.for
-// CHECK: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT]]
-// CHECK: scf.for
+// CHECK: scf.if
+// CHECK: %[[OUTPUT_SUBVIEW_IF:.+]] = memref.subview %[[OUTPUT]]
+// CHECK: linalg.generic
+// CHECK-SAME: outs(%[[OUTPUT_SUBVIEW_IF]]
+// CHECK: else
+// CHECK: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT]]
+// CHECK: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT]]
 // CHECK: scf.for
 // CHECK: scf.for
-// CHECK: %[[OUTPUT_SLICE:.+]] = memref.subview %[[OUTPUT_SUBVIEW]]
-// CHECK: %[[RESULT_VEC:.+]] = scf.if %{{.+}} -> (vector<4xf32>) {
-// CHECK: %[[VEC_LOAD:.+]] = vector.load %[[INPUT_SUBVIEW]]
-// CHECK: scf.yield %[[VEC_LOAD]]
-// CHECK: }
-// CHECK: vector.store %[[RESULT_VEC]], %[[OUTPUT_SLICE]]
+// CHECK: scf.for
+// CHECK: %[[OUTPUT_SLICE:.+]] = memref.subview %[[OUTPUT_SUBVIEW]]
+// CHECK: %[[RESULT_VEC:.+]] = scf.if %{{.+}} -> (vector<4xf32>) {
+// CHECK: %[[VEC_LOAD:.+]] = vector.load %[[INPUT_SUBVIEW]]
+// CHECK: scf.yield %[[VEC_LOAD]]
+// CHECK: }
+// CHECK: vector.store %[[RESULT_VEC]], %[[OUTPUT_SLICE]]
 
 // -----
 
@@ -117,36 +122,38 @@ hal.executable private @pad_with_producer {
 // CHECK: %[[OUTPUT:.+]] = hal.interface.binding.subspan {{.+}} : memref<1x30x30x128xf32
 // CHECK: scf.for
 // CHECK: scf.for
-// CHECK-DAG: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT]]
-// CHECK-DAG: %[[FILTER_SUBVIEW:.+]] = memref.subview %[[FILTER]]
-// CHECK-DAG: %[[BIAS_SUBVIEW:.+]] = memref.subview %[[BIAS]]
-// CHECK-DAG: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT]]
-// CHECK: scf.for
+// CHECK: scf.if
+// CHECK: else
+// CHECK-DAG: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT]]
+// CHECK-DAG: %[[FILTER_SUBVIEW:.+]] = memref.subview %[[FILTER]]
+// CHECK-DAG: %[[BIAS_SUBVIEW:.+]] = memref.subview %[[BIAS]]
+// CHECK-DAG: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT]]
 // CHECK: scf.for
-// CHECK-DAG: %[[INPUT_SLICE:.+]] = memref.subview %[[INPUT_SUBVIEW]]
-// CHECK-DAG: %[[BIAS_ALLOC:.+]] = memref.alloca
 // CHECK: scf.for
-// CHECK: %[[FILTER_SLICE:.+]] = memref.subview %[[FILTER_SUBVIEW]]
-// CHECK: %[[FILL_ALLOC:.+]] = memref.alloca
-// CHECK: linalg.fill
-// CHECK-SAME: outs(%[[FILL_ALLOC]]
-// CHECK: %[[CONV_OUTPUT:.+]] = memref.subview %[[FILL_ALLOC]]
-// CHECK: scf.for
-// CHECK: %[[CONV_INPUT:.+]] = memref.subview %[[INPUT_SLICE]]
-// CHECK: %[[CONV_FILTER:.+]] = memref.subview %[[FILTER_SLICE]]
-// CHECK: linalg.conv_2d_nhwc_hwcf
-// CHECK-SAME: ins(%[[CONV_INPUT]], %[[CONV_FILTER]] :
-// CHECK-SAME: outs(%[[CONV_OUTPUT]] :
-// CHECK: %[[BIAS_INPUT:.+]] = memref.subview %[[BIAS_SUBVIEW]]
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[CONV_OUTPUT]], %[[BIAS_INPUT]] :
-// CHECK-SAME: outs(%[[BIAS_ALLOC]]
-// CHECK: %[[OUTPUT_SLICE:.+]] = memref.subview %[[OUTPUT_SUBVIEW]]
-// CHECK: linalg.fill ins(%{{.+}} : f32) outs(%[[OUTPUT_SLICE]]
-// CHECK: %[[INTERIOR_SLICE:.+]] = memref.subview %[[OUTPUT_SLICE]]
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[BIAS_ALLOC]] :
-// CHECK-SAME: outs(%[[INTERIOR_SLICE]] :
+// CHECK-DAG: %[[INPUT_SLICE:.+]] = memref.subview %[[INPUT_SUBVIEW]]
+// CHECK-DAG: %[[BIAS_ALLOC:.+]] = memref.alloca
+// CHECK: scf.for
+// CHECK: %[[FILTER_SLICE:.+]] = memref.subview %[[FILTER_SUBVIEW]]
+// CHECK: %[[FILL_ALLOC:.+]] = memref.alloca
+// CHECK: linalg.fill
+// CHECK-SAME: outs(%[[FILL_ALLOC]]
+// CHECK: %[[CONV_OUTPUT:.+]] = memref.subview %[[FILL_ALLOC]]
+// CHECK: scf.for
+// CHECK: %[[CONV_INPUT:.+]] = memref.subview %[[INPUT_SLICE]]
+// CHECK: %[[CONV_FILTER:.+]] = memref.subview %[[FILTER_SLICE]]
+// CHECK: linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME: ins(%[[CONV_INPUT]], %[[CONV_FILTER]] :
+// CHECK-SAME: outs(%[[CONV_OUTPUT]] :
+// CHECK: %[[BIAS_INPUT:.+]] = memref.subview %[[BIAS_SUBVIEW]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[CONV_OUTPUT]], %[[BIAS_INPUT]] :
+// CHECK-SAME: outs(%[[BIAS_ALLOC]]
+// CHECK: %[[OUTPUT_SLICE:.+]] = memref.subview %[[OUTPUT_SUBVIEW]]
+// CHECK: linalg.fill ins(%{{.+}} : f32) outs(%[[OUTPUT_SLICE]]
+// CHECK: %[[INTERIOR_SLICE:.+]] = memref.subview %[[OUTPUT_SLICE]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[BIAS_ALLOC]] :
+// CHECK-SAME: outs(%[[INTERIOR_SLICE]] :
 
 // -----
 
@@ -195,7 +202,6 @@ hal.executable private @pad_consumer_fusion {
 // CHECK: %[[INPUT:.+]] = hal.interface.binding.subspan {{.+}} : memref<1x14x14x256xf32>
 // CHECK: %[[FILTER:.+]] = hal.interface.binding.subspan {{.+}} : memref<3x3x256x256xf32>
 // CHECK: %[[OUTPUT:.+]] = hal.interface.binding.subspan {{.+}} : memref<1x14x14x256xf32>
-// CHECK-DAG: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT]]
 // CHECK-DAG: %[[FILTER_SUBVIEW:.+]] = memref.subview %[[FILTER]]
 // CHECK-DAG: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT]]
 // CHECK: scf.for
@@ -205,7 +211,7 @@ hal.executable private @pad_consumer_fusion {
 // CHECK: scf.for
 // CHECK: scf.for
 // CHECK: scf.for
-// CHECK-COUNT-7: vector.load %[[INPUT_SUBVIEW]]
+// CHECK-COUNT-7: vector.load %[[INPUT]]
 // CHECK-COUNT-8: vector.load %[[FILTER_SUBVIEW]]
 // CHECK-COUNT-8: vector.outerproduct
 // CHECK: scf.yield
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index 4bca1ec5fd885..d4fd56ff3f00f 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -496,19 +496,8 @@ std::optional<LoopTilingAndDistributionInfo> isTiledAndDistributedLoop(
 }
 
 SmallVector<Operation *> getComputeOps(func::FuncOp funcOp) {
-  Block *body = &funcOp.getFunctionBody().front();
-  auto forOps = body->getOps<scf::ForOp>();
-  while (!forOps.empty()) {
-    assert(llvm::hasSingleElement(forOps) &&
-           "expected dispatch function with single block");
-    scf::ForOp forOp = *(forOps.begin());
-    body = forOp.getBody();
-    forOps = body->getOps<scf::ForOp>();
-  }
   SmallVector<Operation *> computeOps;
-  for (auto op : body->getOps<TilingInterface>()) {
-    computeOps.push_back(op);
-  }
+  funcOp.walk([&](TilingInterface op) { computeOps.push_back(op); });
   return computeOps;
 }
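For context, this is the IR shape the new `foldIfGeneratedFromPadding` WAR targets. The sketch below is hand-written and illustrative only (names, shapes, and values are not taken from the patch or its tests): tiling a `tensor.pad` op through the TilingInterface wraps the tiled pad in an `scf.if` that guards the degenerate case where the requested tile lies entirely inside the padded region.

    %tile = scf.if %tile_is_all_padding -> (tensor<4x8xf32>) {
      // Degenerate case: materialize a tile holding only the pad value.
      %gen = tensor.generate {
      ^bb0(%i: index, %j: index):
        tensor.yield %pad_value : f32
      } : tensor<4x8xf32>
      scf.yield %gen : tensor<4x8xf32>
    } else {
      // Regular case: pad a slice of the source tensor.
      %padded = tensor.pad %source_slice low[%l0, %l1] high[%h0, %h1] {
      ^bb0(%i: index, %j: index):
        tensor.yield %pad_value : f32
      } : tensor<?x?xf32> to tensor<4x8xf32>
      scf.yield %padded : tensor<4x8xf32>
    }

`foldIfGeneratedFromPadding` inlines the block containing the tiled `tensor.pad` ahead of the `scf.if` and replaces the `scf.if` results with that block's yielded values, erasing the guard. What remains is a plain `tensor.pad` on the tile, which the `FuseTensorPadWithConsumerPass` and `ConcretizePadResultShapePass` runs added to the pipelines above can then fuse and concretize.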