Pad handling without changing upstream interface. (iree-org#13133)
The current default dispatch region formation has options to

- disable splitting pad into fill + tensor.insert_slice
- allow fusion of pad with producer
- allow fusion of pad with consumer

While none of these are on by default, this PR adds support for handling them in the CPU backend. The current state is:

- A pad by itself in a dispatch gets vectorized (a minimal sketch of such a dispatch follows this list).
- A pad fused with its consumer gets vectorized too.
- A pad fused with its producer does not get vectorized. This requires more work and potentially some changes to get the IR into a better state w.r.t. destination passing.
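
A minimal sketch of the first case, assuming a hypothetical function name and shapes, is a dispatch whose only compute op is a `tensor.pad`:

```mlir
// Hypothetical standalone pad: a 12x40 tensor padded to 16x64 with a constant
// value. With this change the CPU backend tiles, distributes, and vectorizes
// this op directly instead of relying on it being split into fill + insert_slice.
func.func @pad_only(%source: tensor<12x40xf32>, %pad_value: f32) -> tensor<16x64xf32> {
  %0 = tensor.pad %source low[2, 4] high[2, 20] {
  ^bb0(%i: index, %j: index):
    tensor.yield %pad_value : f32
  } : tensor<12x40xf32> to tensor<16x64xf32>
  return %0 : tensor<16x64xf32>
}
```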

There are lit tests that show how the different modes are handled today within the CPU backend. To get things working, one thing to handle is that the code generated by tiling the pad operation is of the form

```
scf.if {
  ...
} else {
  ... tensor.pad 
}
```

The `scf.if` here accounts for cases where a tile could be reading only the padding. This does not happen in IREE, so there is a temporary hack that just folds the `if` away. Long term, a better solution is needed (probably requiring a rethink of the pad specification and its tiling).
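
As a rough sketch (hypothetical value names, all-dynamic shapes), once the `if` is folded away only the `else` body survives, so each tile reduces to an extract_slice of the source feeding a smaller `tensor.pad`:

```mlir
// Sketch of the per-tile IR after the workaround: the guarding scf.if is gone
// and the tile is just a slice of the source padded by whatever padding falls
// into this tile. Offsets and sizes would normally come from the tiling loops.
func.func @tiled_pad_after_fold(
    %source: tensor<?x?xf32>, %off0: index, %off1: index,
    %sz0: index, %sz1: index, %h0: index, %h1: index,
    %pad_value: f32) -> tensor<?x?xf32> {
  %slice = tensor.extract_slice %source[%off0, %off1] [%sz0, %sz1] [1, 1]
      : tensor<?x?xf32> to tensor<?x?xf32>
  %padded = tensor.pad %slice low[0, 0] high[%h0, %h1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  return %padded : tensor<?x?xf32>
}
```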
MaheshRavishankar authored and NatashaKnk committed Jul 6, 2023
1 parent ab0d0bc commit 6f72763
Showing 11 changed files with 410 additions and 62 deletions.
@@ -175,6 +175,19 @@ LogicalResult eliminateEmptyTensors(

void EliminateEmptyTensorsPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext *context = &getContext();

// Run the convert to destination style patterns.
{
RewritePatternSet patterns(context);
linalg::populateConvertToDestinationStylePatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
moduleOp->emitOpError(
"Failed in conversion to destination style patterns");
return signalPassFailure();
}
}

OneShotBufferizationOptions options = getBufferizationOptions();

IRRewriter rewriter(moduleOp->getContext());
@@ -450,9 +450,9 @@ static SmallVector<Operation *> getAllFusableProducers(TilingInterface op) {
Operation *currOp = worklist.front();
worklist.pop_front();
for (OpOperand &operand : currOp->getOpOperands()) {
auto tilingInterfaceProducer =
operand.get().getDefiningOp<TilingInterface>();
if (!tilingInterfaceProducer ||
Operation *definingOp = operand.get().getDefiningOp();
auto tilingInterfaceProducer = dyn_cast<TilingInterface>(definingOp);
if (!tilingInterfaceProducer || isa<tensor::PadOp>(definingOp) ||
producers.count(tilingInterfaceProducer)) {
continue;
}
@@ -239,6 +239,8 @@ void registerPartitionableLoopsInterfaceModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
tensor::PackOp::attachInterface<
OuterParallelAsPartitionableLoops<tensor::PackOp>>(*ctx);
tensor::PadOp::attachInterface<
OuterParallelAsPartitionableLoops<tensor::PadOp>>(*ctx);
tensor::UnPackOp::attachInterface<
OuterParallelAsPartitionableLoops<tensor::UnPackOp>>(*ctx);
});
111 changes: 97 additions & 14 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -171,6 +171,22 @@ static llvm::raw_ostream &operator<<(
return os;
}

/// Splits the given `Range` vector and returns the `lbs` and the `ubs` as
/// separate lists.
static void getBoundsFromRange(ArrayRef<Range> loopRange,
SmallVector<int64_t> &lb,
SmallVector<int64_t> &ub) {
auto getStaticValue = [](OpFoldResult ofr) -> int64_t {
Optional<int64_t> intVal = getConstantIntValue(ofr);
if (!intVal) return ShapedType::kDynamic;
return intVal.value();
};
lb = llvm::to_vector(llvm::map_range(
loopRange, [&](Range r) { return getStaticValue(r.offset); }));
ub = llvm::to_vector(llvm::map_range(
loopRange, [&](Range r) { return getStaticValue(r.size); }));
}

/// Returns true if all the input and output tensor operands of 'op' are fully
/// dynamic.
static bool isFullyDynamicOp(linalg::LinalgOp op) {
@@ -1751,6 +1767,46 @@ static LogicalResult setRootConfig(func::FuncOp entryPointFn,
return setConvRootConfig(entryPointFn, convOp, targetTileSizes, vectorSize);
}

static LogicalResult setRootConfig(func::FuncOp entryPointFn,
tensor::PadOp padOp) {
OpBuilder builder(padOp.getContext());
builder.setInsertionPoint(padOp);
SmallVector<Range> iterationDomain =
cast<TilingInterface>(padOp.getOperation()).getIterationDomain(builder);
SmallVector<int64_t> lbs, ubs;
getBoundsFromRange(iterationDomain, lbs, ubs);

SmallVector<int64_t> minTileSizes(lbs.size(), 1);
SmallVector<int64_t> maxTileSizes(ubs.size(), defaultWorkgroupTileSize);
SmallVector<int64_t> vectorTileSizes(lbs.size(), 1);

unsigned typeWidthInBytes = IREE::Util::getRoundedElementByteWidth(
padOp.getResultType().getElementType());
int64_t typeVectorSize = getVectorSize(entryPointFn, typeWidthInBytes);
vectorTileSizes.back() = (ubs.back() == ShapedType::kDynamic
? 1
: std::min(typeVectorSize, ubs.back()));
minTileSizes.back() = vectorTileSizes.back();

SmallVector<unsigned> partitionableLoops =
cast<PartitionableLoopsInterface>(padOp.getOperation())
.getPartitionableLoops(kNumMaxParallelDims);
SmallVector<int64_t> distributedTileSizes =
getDefaultDistributedLevelTileSizes(partitionableLoops, lbs, ubs,
minTileSizes, maxTileSizes);
TileSizesListType tileSizes;
// Distribution tiling
tileSizes.emplace_back(std::move(distributedTileSizes));
// Tiling for vectorization.
tileSizes.emplace_back(std::move(vectorTileSizes));
// No further tiling.
tileSizes.push_back({});

return setOpConfigAndEntryPointFnTranslation(
entryPointFn, padOp, tileSizes,
DispatchLoweringPassPipeline::CPUDoubleTilingExpert);
}

/// Set default configuration for Linalg ops.
static LogicalResult setRootConfig(
func::FuncOp entryPointFn, linalg::LinalgOp linalgOp,
@@ -1812,12 +1868,13 @@ static LogicalResult setRootConfigImpl(
return setRootConfig(entryPointFn, op, LinalgOpInfo(op),
targetMLTransInfo);
})
.Case<IREE::LinalgExt::FftOp, tensor::PackOp, linalg::Mmt4DOp,
linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNchwFchwOp,
linalg::PoolingNhwcSumOp, linalg::PoolingNhwcMaxOp,
linalg::PoolingNhwcMaxUnsignedOp, linalg::PoolingNhwcMinOp,
linalg::PoolingNhwcMinUnsignedOp, linalg::PoolingNchwSumOp,
linalg::PoolingNchwMaxOp, linalg::DepthwiseConv2DNhwcHwcOp>(
.Case<IREE::LinalgExt::FftOp, tensor::PackOp, tensor::PadOp,
linalg::Mmt4DOp, linalg::Conv2DNhwcHwcfOp,
linalg::Conv2DNchwFchwOp, linalg::PoolingNhwcSumOp,
linalg::PoolingNhwcMaxOp, linalg::PoolingNhwcMaxUnsignedOp,
linalg::PoolingNhwcMinOp, linalg::PoolingNhwcMinUnsignedOp,
linalg::PoolingNchwSumOp, linalg::PoolingNchwMaxOp,
linalg::DepthwiseConv2DNhwcHwcOp>(
[&](auto op) { return setRootConfig(entryPointFn, op); })
.Case<tensor::UnPackOp>(
[&](auto op) { return setUnPackOpRootConfig(entryPointFn, op); })
@@ -1867,21 +1924,47 @@ static LogicalResult setVMVXRootConfigImpl(func::FuncOp entryPointFn,
/// to the end of the function is the root op.
static FailureOr<Operation *> getRootOperation(
ArrayRef<Operation *> computeOps) {
Operation *rootOperation = nullptr;
for (auto op : llvm::reverse(computeOps)) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp) continue;
if (linalgOp.getNumReductionLoops()) return op;
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
// Do not treat linalg ops that are all parallel as root operations in
// this sweep.
if (linalgOp.getNumLoops() == linalgOp.getNumParallelLoops()) continue;

// All other linalg ops are root ops.
rootOperation = op;
break;
}

if (isa<TilingInterface>(op) &&
!isa<tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(op)) {
// All other operations that implement this interface are root ops.
rootOperation = op;
break;
}
}

for (auto op : llvm::reverse(computeOps)) {
if (isa<linalg::LinalgOp, IREE::LinalgExt::LinalgExtOp>(op)) return op;
if (!rootOperation) {
// Check for elementwise operations.
for (auto op : llvm::reverse(computeOps)) {
if (isa<linalg::LinalgOp>(op)) {
rootOperation = op;
break;
}
}
}

for (auto op : llvm::reverse(computeOps)) {
if (isa<TilingInterface>(op)) return op;
if (!rootOperation) {
// Check for pad/pack/unpack ops by themselves.
for (auto op : llvm::reverse(computeOps)) {
if (isa<tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(op)) {
rootOperation = op;
break;
}
}
}

return nullptr;
return rootOperation;
}

static LogicalResult adjustTileSizesForPackOp(func::FuncOp entryPointFn,
40 changes: 40 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
@@ -45,6 +45,35 @@ static void collectTiledAndFusedOps(Operation *rootOp,
}
}

/// Tiling of a `tensor.pad` operation generates
///
/// ```mlir
/// scf.if {
/// ...
/// } else {
/// tensor.pad
/// }
/// ```
///
/// For IREE's use case we don't need this, so this folds away the `if` condition.
/// Note this is a fairly hacky workaround, but the current pad operation
/// semantics force us down this path.
static FailureOr<tensor::PadOp> foldIfGeneratedFromPadding(
RewriterBase &rewriter, tensor::PadOp untiledPadOp,
tensor::PadOp tiledPadOp) {
auto ifOp = dyn_cast<scf::IfOp>(tiledPadOp->getParentOp());
if (!ifOp) {
return failure();
}
Block *block = tiledPadOp->getBlock();
Operation *terminator = block->getTerminator();
ValueRange results = terminator->getOperands();
rewriter.inlineBlockBefore(block, ifOp, /*blockArgs=*/{});
rewriter.replaceOp(ifOp, results);
rewriter.eraseOp(terminator);
return tiledPadOp;
}

/// This pass starts with the last TilingInterface operation, tiles the op and
/// fuses its producers recursively. The `tilingLevel` must be specified. It
/// picks the `tilingLevel`-th list as tiling sizes from lowering_config.
@@ -83,6 +112,17 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp,
}
yieldedValuesToOrigValues.append(rootOp->result_begin(),
rootOp->result_end());

// WAR for tiling of `tensor.pad` ops generating `scf.if` operations.
if (auto rootPadOp = dyn_cast<tensor::PadOp>(rootOp)) {
assert(tilingResult->tiledOps.size() == 1 &&
"expected tiling of `pad` op to return only one operation");
FailureOr<Operation *> replacementTiledOp = foldIfGeneratedFromPadding(
rewriter, rootPadOp, cast<tensor::PadOp>(tilingResult->tiledOps[0]));
if (!failed(replacementTiledOp)) {
tilingResult->tiledOps[0] = replacementTiledOp.value();
}
}
tiledOps.append(tilingResult->tiledOps);

// 2. Tiling each operation results in generation of slices. The source of
59 changes: 28 additions & 31 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -125,8 +125,7 @@ static void addBufferizePasses(OpPassManager &passManager) {
createEraseHALDescriptorTypeFromMemRefPass());
}

static void addTileAndDistributePasses(
OpPassManager &pm, bool useFuseTensorPadWithConsumerPass = true) {
static void addTileAndDistributePasses(OpPassManager &pm) {
pm.addPass(createTileAndDistributeToWorkgroupsPass());
auto &nestedModulePM = pm.nest<ModuleOp>();
nestedModulePM.addNestedPass<func::FuncOp>(
@@ -135,10 +134,10 @@ static void addTileAndDistributePasses(
createFoldAffineMinInDistributedLoopsPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
if (clEnablePadConsumerFusion && useFuseTensorPadWithConsumerPass) {
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
}
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());
nestedModulePM.addNestedPass<func::FuncOp>(
IREE::LinalgExt::createTileAndDecomposeAttentionPass());
nestedModulePM.addNestedPass<func::FuncOp>(
@@ -353,8 +352,7 @@ void addCPUBufferOpsTileAndVectorizePipeline(OpPassManager &passManager,

void addDoubleTilingPadExpertPassPipeline(OpPassManager &passManager,
bool enableVectorMasking) {
addTileAndDistributePasses(passManager,
/*useFuseTensorPadWithConsumerPass=*/false);
addTileAndDistributePasses(passManager);

OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
@@ -394,8 +392,7 @@ void addDoubleTilingPadExpertPassPipeline(OpPassManager &passManager,

void addVMVXDefaultPassPipeline(OpPassManager &passManager,
bool enableMicrokernels) {
addTileAndDistributePasses(passManager,
/*useFuseTensorPadWithConsumerPass=*/false);
addTileAndDistributePasses(passManager);

if (enableMicrokernels) {
passManager.nest<ModuleOp>().addPass(createLLVMCPULowerToUKernelsPass());
@@ -441,6 +438,10 @@ void addMultiTilingExpertPassPipeline(OpPassManager &passManager,

for (int64_t i = 1; i < numLevels - 1; ++i) {
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(i));
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());
}
// Run SplitReductionPass before the final reduction Fuse pass, because
// SplitReductionPass takes care of banked-tiling.
@@ -449,13 +450,10 @@
nestedModulePM.addNestedPass<func::FuncOp>(
createLLVMCPUTilePass(numLevels - 1));

if (clEnablePadConsumerFusion) {
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());
nestedModulePM.addNestedPass<func::FuncOp>(createVectorizePadPass());
}
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());

if (enablePeeling) {
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUPeelPass());
@@ -466,6 +464,8 @@ void addMultiTilingExpertPassPipeline(OpPassManager &passManager,
createDecomposePackUnPackOpsPass());
nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());

nestedModulePM.addNestedPass<func::FuncOp>(createVectorizePadPass());
LLVMCPUVectorizationPassOptions options;
options.enableVectorMasking = enableVectorMasking;
// TODO(#13036): Re-enable once debugged.
@@ -503,26 +503,23 @@ void addConvTileAndDecomposeExpertPassPipeline(OpPassManager &passManager,

nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
static_cast<int64_t>(TilingLevel::ParallelTiles)));
if (clEnablePadConsumerFusion) {
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());
}
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());

nestedModulePM.addNestedPass<func::FuncOp>(
createLLVMCPUTilePass(static_cast<int64_t>(TilingLevel::ReductionTiles)));
nestedModulePM.addNestedPass<func::FuncOp>(
createDecomposeConvolutionToLowerDimOpsPass());

if (clEnablePadConsumerFusion) {
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());
nestedModulePM.addNestedPass<func::FuncOp>(createVectorizePadPass());
}
nestedModulePM.addNestedPass<func::FuncOp>(
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());

{
nestedModulePM.addNestedPass<func::FuncOp>(createVectorizePadPass());
LLVMCPUVectorizationPassOptions options;
options.enableVectorMasking = enableVectorMasking;
options.vectorizePadding = true;
@@ -537,7 +534,7 @@
nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());
nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createOptimizeVectorTransferPass(/*flatten=*/false));
createOptimizeVectorTransferPass(/*flatten=*/true));
addBufferizePasses(nestedModulePM);

// Run IREE specific passes before vector lowering expert.
@@ -42,6 +42,7 @@ iree_lit_test_suite(
"materialize_vmvx_launch_configuration.mlir",
"materialize_x86_64_launch_configuration.mlir",
"pad_conv_pipeline_tests.mlir",
"pad_pipeline_tests.mlir",
"peel.mlir",
"peel_and_vectorize.mlir",
"pipeline_tests.mlir",
@@ -37,6 +37,7 @@ iree_lit_test_suite(
"materialize_vmvx_launch_configuration.mlir"
"materialize_x86_64_launch_configuration.mlir"
"pad_conv_pipeline_tests.mlir"
"pad_pipeline_tests.mlir"
"peel.mlir"
"peel_and_vectorize.mlir"
"pipeline_tests.mlir"
