Skip to content

Commit

Permalink
[CodeGen] Support start/stop in CodeGenPassBuilder (llvm#70912)
Browse files Browse the repository at this point in the history
Add `-start/stop-before/after` support for CodeGenPassBuilder.
Part of llvm#69879.
  • Loading branch information
paperchalice authored Jan 18, 2024
1 parent 9d1dada commit baaf0c9
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 21 deletions.
133 changes: 113 additions & 20 deletions llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "llvm/CodeGen/ShadowStackGCLowering.h"
#include "llvm/CodeGen/SjLjEHPrepare.h"
#include "llvm/CodeGen/StackProtector.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/UnreachableBlockElim.h"
#include "llvm/CodeGen/WasmEHPrepare.h"
#include "llvm/CodeGen/WinEHPrepare.h"
Expand Down Expand Up @@ -176,73 +177,80 @@ template <typename DerivedT> class CodeGenPassBuilder {
// Function object to maintain state while adding codegen IR passes.
class AddIRPass {
public:
AddIRPass(ModulePassManager &MPM) : MPM(MPM) {}
AddIRPass(ModulePassManager &MPM, const DerivedT &PB) : MPM(MPM), PB(PB) {}
~AddIRPass() {
if (!FPM.isEmpty())
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
}

template <typename PassT> void operator()(PassT &&Pass) {
template <typename PassT>
void operator()(PassT &&Pass, StringRef Name = PassT::name()) {
static_assert((is_detected<is_function_pass_t, PassT>::value ||
is_detected<is_module_pass_t, PassT>::value) &&
"Only module pass and function pass are supported.");

if (!PB.runBeforeAdding(Name))
return;

// Add Function Pass
if constexpr (is_detected<is_function_pass_t, PassT>::value) {
FPM.addPass(std::forward<PassT>(Pass));

for (auto &C : PB.AfterCallbacks)
C(Name);
} else {
// Add Module Pass
if (!FPM.isEmpty()) {
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
FPM = FunctionPassManager();
}

MPM.addPass(std::forward<PassT>(Pass));

for (auto &C : PB.AfterCallbacks)
C(Name);
}
}

private:
ModulePassManager &MPM;
FunctionPassManager FPM;
const DerivedT &PB;
};

// Function object to maintain state while adding codegen machine passes.
class AddMachinePass {
public:
AddMachinePass(MachineFunctionPassManager &PM) : PM(PM) {}
AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB)
: PM(PM), PB(PB) {}

template <typename PassT> void operator()(PassT &&Pass) {
static_assert(
is_detected<has_key_t, PassT>::value,
"Machine function pass must define a static member variable `Key`.");
for (auto &C : BeforeCallbacks)
if (!C(&PassT::Key))
return;

if (!PB.runBeforeAdding(PassT::name()))
return;

PM.addPass(std::forward<PassT>(Pass));
for (auto &C : AfterCallbacks)
C(&PassT::Key);

for (auto &C : PB.AfterCallbacks)
C(PassT::name());
}

template <typename PassT> void insertPass(MachinePassKey *ID, PassT Pass) {
AfterCallbacks.emplace_back(
PB.AfterCallbacks.emplace_back(
[this, ID, Pass = std::move(Pass)](MachinePassKey *PassID) {
if (PassID == ID)
this->PM.addPass(std::move(Pass));
});
}

void disablePass(MachinePassKey *ID) {
BeforeCallbacks.emplace_back(
[ID](MachinePassKey *PassID) { return PassID != ID; });
}

MachineFunctionPassManager releasePM() { return std::move(PM); }

private:
MachineFunctionPassManager &PM;
SmallVector<llvm::unique_function<bool(MachinePassKey *)>, 4>
BeforeCallbacks;
SmallVector<llvm::unique_function<void(MachinePassKey *)>, 4>
AfterCallbacks;
const DerivedT &PB;
};

LLVMTargetMachine &TM;
Expand Down Expand Up @@ -473,20 +481,43 @@ template <typename DerivedT> class CodeGenPassBuilder {
const DerivedT &derived() const {
return static_cast<const DerivedT &>(*this);
}

bool runBeforeAdding(StringRef Name) const {
bool ShouldAdd = true;
for (auto &C : BeforeCallbacks)
ShouldAdd &= C(Name);
return ShouldAdd;
}

void setStartStopPasses(const TargetPassConfig::StartStopInfo &Info) const;

Error verifyStartStop(const TargetPassConfig::StartStopInfo &Info) const;

mutable SmallVector<llvm::unique_function<bool(StringRef)>, 4>
BeforeCallbacks;
mutable SmallVector<llvm::unique_function<void(StringRef)>, 4> AfterCallbacks;

/// Helper variable for `-start-before/-start-after/-stop-before/-stop-after`
mutable bool Started = true;
mutable bool Stopped = true;
};

template <typename Derived>
Error CodeGenPassBuilder<Derived>::buildPipeline(
ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
CodeGenFileType FileType) const {
AddIRPass addIRPass(MPM);
auto StartStopInfo = TargetPassConfig::getStartStopInfo(*PIC);
if (!StartStopInfo)
return StartStopInfo.takeError();
setStartStopPasses(*StartStopInfo);
AddIRPass addIRPass(MPM, derived());
// `ProfileSummaryInfo` is always valid.
addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
addIRPass(RequireAnalysisPass<CollectorMetadataAnalysis, Module>());
addISelPasses(addIRPass);

AddMachinePass addPass(MFPM);
AddMachinePass addPass(MFPM, derived());
if (auto Err = addCoreISelPasses(addPass))
return std::move(Err);

Expand All @@ -499,6 +530,68 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
});

addPass(FreeMachineFunctionPass());
return verifyStartStop(*StartStopInfo);
}

template <typename Derived>
void CodeGenPassBuilder<Derived>::setStartStopPasses(
const TargetPassConfig::StartStopInfo &Info) const {
if (!Info.StartPass.empty()) {
Started = false;
BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StartAfter,
Count = 0u](StringRef ClassName) mutable {
if (Count == Info.StartInstanceNum) {
if (AfterFlag) {
AfterFlag = false;
Started = true;
}
return Started;
}

auto PassName = PIC->getPassNameForClassName(ClassName);
if (Info.StartPass == PassName && ++Count == Info.StartInstanceNum)
Started = !Info.StartAfter;

return Started;
});
}

if (!Info.StopPass.empty()) {
Stopped = false;
BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StopAfter,
Count = 0u](StringRef ClassName) mutable {
if (Count == Info.StopInstanceNum) {
if (AfterFlag) {
AfterFlag = false;
Stopped = true;
}
return !Stopped;
}

auto PassName = PIC->getPassNameForClassName(ClassName);
if (Info.StopPass == PassName && ++Count == Info.StopInstanceNum)
Stopped = !Info.StopAfter;
return !Stopped;
});
}
}

template <typename Derived>
Error CodeGenPassBuilder<Derived>::verifyStartStop(
const TargetPassConfig::StartStopInfo &Info) const {
if (Started && Stopped)
return Error::success();

if (!Started)
return make_error<StringError>(
"Can't find start pass \"" +
PIC->getPassNameForClassName(Info.StartPass) + "\".",
std::make_error_code(std::errc::invalid_argument));
if (!Stopped)
return make_error<StringError>(
"Can't find stop pass \"" +
PIC->getPassNameForClassName(Info.StopPass) + "\".",
std::make_error_code(std::errc::invalid_argument));
return Error::success();
}

Expand Down
15 changes: 15 additions & 0 deletions llvm/include/llvm/CodeGen/TargetPassConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "llvm/Pass.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Error.h"
#include <cassert>
#include <string>

Expand Down Expand Up @@ -176,6 +177,20 @@ class TargetPassConfig : public ImmutablePass {
static std::string
getLimitedCodeGenPipelineReason(const char *Separator = "/");

struct StartStopInfo {
bool StartAfter;
bool StopAfter;
unsigned StartInstanceNum;
unsigned StopInstanceNum;
StringRef StartPass;
StringRef StopPass;
};

/// Returns pass name in `-stop-before` or `-stop-after`
/// NOTE: New pass manager migration only
static Expected<StartStopInfo>
getStartStopInfo(PassInstrumentationCallbacks &PIC);

void setDisableVerify(bool Disable) { setOpt(DisableVerify, Disable); }

bool getEnableTailMerge() const { return EnableTailMerge; }
Expand Down
34 changes: 34 additions & 0 deletions llvm/lib/CodeGen/TargetPassConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,40 @@ void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
registerPartialPipelineCallback(PIC, LLVMTM);
}

Expected<TargetPassConfig::StartStopInfo>
TargetPassConfig::getStartStopInfo(PassInstrumentationCallbacks &PIC) {
auto [StartBefore, StartBeforeInstanceNum] =
getPassNameAndInstanceNum(StartBeforeOpt);
auto [StartAfter, StartAfterInstanceNum] =
getPassNameAndInstanceNum(StartAfterOpt);
auto [StopBefore, StopBeforeInstanceNum] =
getPassNameAndInstanceNum(StopBeforeOpt);
auto [StopAfter, StopAfterInstanceNum] =
getPassNameAndInstanceNum(StopAfterOpt);

if (!StartBefore.empty() && !StartAfter.empty())
return make_error<StringError>(
Twine(StartBeforeOptName) + " and " + StartAfterOptName + " specified!",
std::make_error_code(std::errc::invalid_argument));
if (!StopBefore.empty() && !StopAfter.empty())
return make_error<StringError>(
Twine(StopBeforeOptName) + " and " + StopAfterOptName + " specified!",
std::make_error_code(std::errc::invalid_argument));

StartStopInfo Result;
Result.StartPass = StartBefore.empty() ? StartAfter : StartBefore;
Result.StopPass = StopBefore.empty() ? StopAfter : StopBefore;
Result.StartInstanceNum =
StartBefore.empty() ? StartAfterInstanceNum : StartBeforeInstanceNum;
Result.StopInstanceNum =
StopBefore.empty() ? StopAfterInstanceNum : StopBeforeInstanceNum;
Result.StartAfter = !StartAfter.empty();
Result.StopAfter = !StopAfter.empty();
Result.StartInstanceNum += Result.StartInstanceNum == 0;
Result.StopInstanceNum += Result.StopInstanceNum == 0;
return Result;
}

// Out of line constructor provides default values for pass options and
// registers all common codegen passes.
TargetPassConfig::TargetPassConfig(LLVMTargetMachine &TM, PassManagerBase &pm)
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
#include "llvm/CodeGen/ShadowStackGCLowering.h"
#include "llvm/CodeGen/SjLjEHPrepare.h"
#include "llvm/CodeGen/StackProtector.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TypePromotion.h"
#include "llvm/CodeGen/WasmEHPrepare.h"
#include "llvm/CodeGen/WinEHPrepare.h"
Expand Down Expand Up @@ -316,7 +317,8 @@ namespace {
/// We currently only use this for --print-before/after.
bool shouldPopulateClassToPassNames() {
return PrintPipelinePasses || !printBeforePasses().empty() ||
!printAfterPasses().empty() || !isFilterPassesEmpty();
!printAfterPasses().empty() || !isFilterPassesEmpty() ||
TargetPassConfig::hasLimitedCodeGenPipeline();
}

// A pass for testing -print-on-crash.
Expand Down
41 changes: 41 additions & 0 deletions llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,45 @@ TEST_F(CodeGenPassBuilderTest, basic) {
EXPECT_EQ(MIRPipeline, ExpectedMIRPipeline);
}

// TODO: Move this to lit test when llc support new pm.
TEST_F(CodeGenPassBuilderTest, start_stop) {
static const char *argv[] = {
"test",
"-start-after=no-op-module",
"-stop-before=no-op-function,2",
};
int argc = std::size(argv);
cl::ParseCommandLineOptions(argc, argv);

LoopAnalysisManager LAM;
FunctionAnalysisManager FAM;
CGSCCAnalysisManager CGAM;
ModuleAnalysisManager MAM;

PassInstrumentationCallbacks PIC;
DummyCodeGenPassBuilder CGPB(*TM, getCGPassBuilderOption(), &PIC);
PipelineTuningOptions PTO;
PassBuilder PB(TM.get(), PTO, std::nullopt, &PIC);

PB.registerModuleAnalyses(MAM);
PB.registerCGSCCAnalyses(CGAM);
PB.registerFunctionAnalyses(FAM);
PB.registerLoopAnalyses(LAM);
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

ModulePassManager MPM;
MachineFunctionPassManager MFPM;

Error Err =
CGPB.buildPipeline(MPM, MFPM, outs(), nullptr, CodeGenFileType::Null);
EXPECT_FALSE(Err);
std::string IRPipeline;
raw_string_ostream IROS(IRPipeline);
MPM.printPipeline(IROS, [&PIC](StringRef Name) {
auto PassName = PIC.getPassNameForClassName(Name);
return PassName.empty() ? Name : PassName;
});
EXPECT_EQ(IRPipeline, "function(no-op-function)");
}

} // namespace

0 comments on commit baaf0c9

Please sign in to comment.