diff --git a/mlir/include/mlir/Transforms/Inliner.h b/mlir/include/mlir/Transforms/Inliner.h index 1fe61fb4bbe7d9..073b83f6f844c5 100644 --- a/mlir/include/mlir/Transforms/Inliner.h +++ b/mlir/include/mlir/Transforms/Inliner.h @@ -69,19 +69,6 @@ class InlinerConfig { /// of inlining decisions from the leafs to the roots of the callgraph. class Inliner { public: - using RunPipelineHelperTy = std::function; - - Inliner(Operation *op, CallGraph &cg, Pass &pass, AnalysisManager am, - RunPipelineHelperTy runPipelineHelper, const InlinerConfig &config) - : op(op), cg(cg), pass(pass), am(am), - runPipelineHelper(std::move(runPipelineHelper)), config(config) {} - Inliner(Inliner &) = delete; - void operator=(const Inliner &) = delete; - - /// Perform inlining on a OpTrait::SymbolTable operation. - LogicalResult doInlining(); - /// This struct represents a resolved call to a given callgraph node. Given /// that the call does not actually contain a direct reference to the /// Region(CallGraphNode) that it is dispatching to, we need to resolve them @@ -94,7 +81,29 @@ class Inliner { CallGraphNode *sourceNode, *targetNode; }; -protected: + using RunPipelineHelperTy = std::function; + + /// Type of the callback answering if it is profitable + /// to inline a callable operation at a call site. + /// It might be the case that the ResolvedCall does not provide + /// enough context to make the profitability decision, so + /// this hook's interface might need to be extended in future. + using ProfitabilityCallbackTy = std::function; + + Inliner(Operation *op, CallGraph &cg, Pass &pass, AnalysisManager am, + RunPipelineHelperTy runPipelineHelper, const InlinerConfig &config, + ProfitabilityCallbackTy isProfitableToInline) + : op(op), cg(cg), pass(pass), am(am), + runPipelineHelper(std::move(runPipelineHelper)), config(config), + isProfitableToInline(std::move(isProfitableToInline)) {} + Inliner(Inliner &) = delete; + void operator=(const Inliner &) = delete; + + /// Perform inlining on a OpTrait::SymbolTable operation. + LogicalResult doInlining(); + +private: /// An OpTrait::SymbolTable operation to run the inlining on. Operation *op; /// A CallGraph analysis for the given operation. @@ -108,12 +117,12 @@ class Inliner { const RunPipelineHelperTy runPipelineHelper; /// The inliner configuration parameters. const InlinerConfig &config; + /// Returns true, if it is profitable to inline the callable operation + /// at the call site. + ProfitabilityCallbackTy isProfitableToInline; -private: /// Forward declaration of the class providing the actual implementation. class Impl; - -public: }; } // namespace mlir diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index b8fdf7a580476e..51b2a27da639d6 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -278,6 +278,13 @@ def Inliner : Pass<"inline"> { Option<"maxInliningIterations", "max-iterations", "unsigned", /*default=*/"4", "Maximum number of iterations when inlining within an SCC">, + Option<"inliningThreshold", "inlining-threshold", "unsigned", + /*default=*/"-1U", + "If the ratio between the number of the operations " + "in the callee and the number of the operations " + "in the caller exceeds this value (in percentage), " + "then the callee is not inlined even if it is legal " + "to inline it">, ]; } diff --git a/mlir/lib/Transforms/InlinerPass.cpp b/mlir/lib/Transforms/InlinerPass.cpp index c058e8050cd199..08d8dbf73a6a1d 100644 --- a/mlir/lib/Transforms/InlinerPass.cpp +++ b/mlir/lib/Transforms/InlinerPass.cpp @@ -24,6 +24,8 @@ namespace mlir { #include "mlir/Transforms/Passes.h.inc" } // namespace mlir +#define DEBUG_TYPE "inliner-pass" + using namespace mlir; /// This function implements the inliner optimization pipeline. @@ -88,6 +90,35 @@ InlinerPass::InlinerPass(std::function defaultPipeline, config.setOpPipelines(std::move(opPipelines)); } +// Return true if the inlining ratio does not exceed the threshold. +static bool isProfitableToInline(const Inliner::ResolvedCall &resolvedCall, + unsigned inliningThreshold) { + Region *callerRegion = resolvedCall.sourceNode->getCallableRegion(); + Region *calleeRegion = resolvedCall.targetNode->getCallableRegion(); + + // We should not get external nodes here, but just return true + // for now to preserve the original behavior of the inliner pass. + if (!calleeRegion || !calleeRegion) + return true; + + auto countOps = [](Region *region) { + unsigned count = 0; + region->walk([&](Operation *) { ++count; }); + return count; + }; + + unsigned callerOps = countOps(callerRegion); + + // Always inline empty callees (if it is possible at all). + if (callerOps == 0) + return true; + + unsigned ratio = countOps(calleeRegion) * 100 / callerOps; + LLVM_DEBUG(llvm::dbgs() << "Callee / caller operation ratio (max: " + << inliningThreshold << "%): " << ratio << "%\n"); + return ratio <= inliningThreshold; +} + void InlinerPass::runOnOperation() { CallGraph &cg = getAnalysis(); @@ -100,9 +131,14 @@ void InlinerPass::runOnOperation() { return signalPassFailure(); } + // By default, assume that any inlining is profitable. + auto profitabilityCb = [=](const Inliner::ResolvedCall &call) { + return isProfitableToInline(call, inliningThreshold); + }; + // Get an instance of the inliner. Inliner inliner(op, cg, *this, getAnalysisManager(), runPipelineHelper, - config); + config, profitabilityCb); // Run the inlining. if (failed(inliner.doInlining())) diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index 74776a73db9aaa..c24eff7353f6b9 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -733,6 +733,9 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) { if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks()) return false; + if (!inliner.isProfitableToInline(resolvedCall)) + return false; + // Otherwise, inline. return true; } diff --git a/mlir/test/Transforms/inlining-dump-default-pipeline.mlir b/mlir/test/Transforms/inlining-dump-default-pipeline.mlir index e2c31867a8e045..4f8638054206e8 100644 --- a/mlir/test/Transforms/inlining-dump-default-pipeline.mlir +++ b/mlir/test/Transforms/inlining-dump-default-pipeline.mlir @@ -1,2 +1,2 @@ // RUN: mlir-opt %s -pass-pipeline="builtin.module(inline)" -dump-pass-pipeline 2>&1 | FileCheck %s -// CHECK: builtin.module(inline{default-pipeline=canonicalize max-iterations=4 }) +// CHECK: builtin.module(inline{default-pipeline=canonicalize inlining-threshold=4294967295 max-iterations=4 }) diff --git a/mlir/test/Transforms/inlining-threshold.mlir b/mlir/test/Transforms/inlining-threshold.mlir new file mode 100644 index 00000000000000..b94115d8f26416 --- /dev/null +++ b/mlir/test/Transforms/inlining-threshold.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline='' inlining-threshold=100' -debug-only=inliner-pass 2>&1 | FileCheck %s + +// Check that inlining does not happen when the threshold is exceeded. +func.func @callee1(%arg : i32) -> i32 { + %v1 = arith.addi %arg, %arg : i32 + %v2 = arith.addi %v1, %arg : i32 + %v3 = arith.addi %v2, %arg : i32 + return %v3 : i32 +} + +// CHECK-LABEL: func @caller1 +func.func @caller1(%arg0 : i32) -> i32 { + // CHECK-NEXT: call @callee1 + // CHECK-NEXT: return + + %0 = call @callee1(%arg0) : (i32) -> i32 + return %0 : i32 +}