Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC][mlir] Add profitability callback to the Inliner. #84258

Merged
merged 5 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions mlir/include/mlir/Transforms/Inliner.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,6 @@ class InlinerConfig {
/// of inlining decisions from the leafs to the roots of the callgraph.
class Inliner {
public:
using RunPipelineHelperTy = std::function<LogicalResult(
Pass &pass, OpPassManager &pipeline, Operation *op)>;

Inliner(Operation *op, CallGraph &cg, Pass &pass, AnalysisManager am,
RunPipelineHelperTy runPipelineHelper, const InlinerConfig &config)
: op(op), cg(cg), pass(pass), am(am),
runPipelineHelper(std::move(runPipelineHelper)), config(config) {}
Inliner(Inliner &) = delete;
void operator=(const Inliner &) = delete;

/// Perform inlining on a OpTrait::SymbolTable operation.
LogicalResult doInlining();

/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the
/// Region(CallGraphNode) that it is dispatching to, we need to resolve them
Expand All @@ -94,7 +81,29 @@ class Inliner {
CallGraphNode *sourceNode, *targetNode;
};

protected:
using RunPipelineHelperTy = std::function<LogicalResult(
Pass &pass, OpPassManager &pipeline, Operation *op)>;

/// Type of the callback answering if it is profitable
/// to inline a callable operation at a call site.
/// It might be the case that the ResolvedCall does not provide
/// enough context to make the profitability decision, so
/// this hook's interface might need to be extended in future.
vzakhari marked this conversation as resolved.
Show resolved Hide resolved
using ProfitabilityCallbackTy = std::function<bool(const ResolvedCall &)>;

Inliner(Operation *op, CallGraph &cg, Pass &pass, AnalysisManager am,
RunPipelineHelperTy runPipelineHelper, const InlinerConfig &config,
ProfitabilityCallbackTy isProfitableToInline)
: op(op), cg(cg), pass(pass), am(am),
runPipelineHelper(std::move(runPipelineHelper)), config(config),
isProfitableToInline(std::move(isProfitableToInline)) {}
Inliner(Inliner &) = delete;
void operator=(const Inliner &) = delete;

/// Perform inlining on a OpTrait::SymbolTable operation.
LogicalResult doInlining();

private:
/// An OpTrait::SymbolTable operation to run the inlining on.
Operation *op;
/// A CallGraph analysis for the given operation.
Expand All @@ -108,12 +117,12 @@ class Inliner {
const RunPipelineHelperTy runPipelineHelper;
/// The inliner configuration parameters.
const InlinerConfig &config;
/// Returns true, if it is profitable to inline the callable operation
/// at the call site.
ProfitabilityCallbackTy isProfitableToInline;

private:
/// Forward declaration of the class providing the actual implementation.
class Impl;

public:
};
} // namespace mlir

Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,13 @@ def Inliner : Pass<"inline"> {
Option<"maxInliningIterations", "max-iterations", "unsigned",
/*default=*/"4",
"Maximum number of iterations when inlining within an SCC">,
Option<"inliningThreshold", "inlining-threshold", "unsigned",
/*default=*/"-1U",
"If the ratio between the number of the operations "
"in the callee and the number of the operations "
"in the caller exceeds this value (in percentage), "
"then the callee is not inlined even if it is legal "
"to inline it">,
];
}

Expand Down
38 changes: 37 additions & 1 deletion mlir/lib/Transforms/InlinerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ namespace mlir {
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "inliner-pass"

using namespace mlir;

/// This function implements the inliner optimization pipeline.
Expand Down Expand Up @@ -88,6 +90,35 @@ InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
config.setOpPipelines(std::move(opPipelines));
}

// Return true if the inlining ratio does not exceed the threshold.
static bool isProfitableToInline(const Inliner::ResolvedCall &resolvedCall,
unsigned inliningThreshold) {
Region *callerRegion = resolvedCall.sourceNode->getCallableRegion();
Region *calleeRegion = resolvedCall.targetNode->getCallableRegion();

// We should not get external nodes here, but just return true
// for now to preserve the original behavior of the inliner pass.
if (!calleeRegion || !calleeRegion)
return true;

auto countOps = [](Region *region) {
unsigned count = 0;
region->walk([&](Operation *) { ++count; });
return count;
};

unsigned callerOps = countOps(callerRegion);

// Always inline empty callees (if it is possible at all).
if (callerOps == 0)
return true;

unsigned ratio = countOps(calleeRegion) * 100 / callerOps;
LLVM_DEBUG(llvm::dbgs() << "Callee / caller operation ratio (max: "
<< inliningThreshold << "%): " << ratio << "%\n");
return ratio <= inliningThreshold;
}

void InlinerPass::runOnOperation() {
CallGraph &cg = getAnalysis<CallGraph>();

Expand All @@ -100,9 +131,14 @@ void InlinerPass::runOnOperation() {
return signalPassFailure();
}

// By default, assume that any inlining is profitable.
auto profitabilityCb = [=](const Inliner::ResolvedCall &call) {
return isProfitableToInline(call, inliningThreshold);
};

// Get an instance of the inliner.
Inliner inliner(op, cg, *this, getAnalysisManager(), runPipelineHelper,
config);
config, profitabilityCb);

// Run the inlining.
if (failed(inliner.doInlining()))
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Transforms/Utils/Inliner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,9 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
return false;

if (!inliner.isProfitableToInline(resolvedCall))
return false;

// Otherwise, inline.
return true;
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Transforms/inlining-dump-default-pipeline.mlir
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
// RUN: mlir-opt %s -pass-pipeline="builtin.module(inline)" -dump-pass-pipeline 2>&1 | FileCheck %s
// CHECK: builtin.module(inline{default-pipeline=canonicalize max-iterations=4 })
// CHECK: builtin.module(inline{default-pipeline=canonicalize inlining-threshold=4294967295 max-iterations=4 })
18 changes: 18 additions & 0 deletions mlir/test/Transforms/inlining-threshold.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline='' inlining-threshold=100' -debug-only=inliner-pass 2>&1 | FileCheck %s

// Check that inlining does not happen when the threshold is exceeded.
func.func @callee1(%arg : i32) -> i32 {
%v1 = arith.addi %arg, %arg : i32
%v2 = arith.addi %v1, %arg : i32
%v3 = arith.addi %v2, %arg : i32
return %v3 : i32
}

// CHECK-LABEL: func @caller1
func.func @caller1(%arg0 : i32) -> i32 {
// CHECK-NEXT: call @callee1
// CHECK-NEXT: return

%0 = call @callee1(%arg0) : (i32) -> i32
return %0 : i32
}
Loading