From a9e617e2b449a7460155555d1fe0e978454f3561 Mon Sep 17 00:00:00 2001 From: wangpc Date: Fri, 24 Nov 2023 19:45:06 +0800 Subject: [PATCH] [SelectionDAG] Add space-optimized forms of OPC_CheckPatternPredicate We record the usage of each `PatternPredicate` and sort them by usage. For the top 8 `PatternPredicate`s, we will emit a `OPC_CheckPatternPredicateN` to save one byte. The old `OPC_CheckPatternPredicate2` is renamed to `OPC_CheckPatternPredicateTwoByte`. Overall this reduces the llc binary size with all in-tree targets by about 93K. --- llvm/include/llvm/CodeGen/SelectionDAGISel.h | 8 +++++ .../CodeGen/SelectionDAG/SelectionDAGISel.cpp | 34 ++++++++++++++----- llvm/utils/TableGen/DAGISelMatcherEmitter.cpp | 26 +++++++++----- 3 files changed, 51 insertions(+), 17 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SelectionDAGISel.h b/llvm/include/llvm/CodeGen/SelectionDAGISel.h index 99ce658e7eb711..e4d90f6e898fe8 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGISel.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGISel.h @@ -159,7 +159,15 @@ class SelectionDAGISel : public MachineFunctionPass { OPC_CheckChild2Same, OPC_CheckChild3Same, OPC_CheckPatternPredicate, + OPC_CheckPatternPredicate0, + OPC_CheckPatternPredicate1, OPC_CheckPatternPredicate2, + OPC_CheckPatternPredicate3, + OPC_CheckPatternPredicate4, + OPC_CheckPatternPredicate5, + OPC_CheckPatternPredicate6, + OPC_CheckPatternPredicate7, + OPC_CheckPatternPredicateTwoByte, OPC_CheckPredicate, OPC_CheckPredicateWithOperands, OPC_CheckOpcode, diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp index 344dc8d8a9b677..678d273e4bd605 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -2697,9 +2697,14 @@ LLVM_ATTRIBUTE_ALWAYS_INLINE static bool CheckChildSame( /// CheckPatternPredicate - Implements OP_CheckPatternPredicate. LLVM_ATTRIBUTE_ALWAYS_INLINE static bool -CheckPatternPredicate(const unsigned char *MatcherTable, unsigned &MatcherIndex, - const SelectionDAGISel &SDISel, bool TwoBytePredNo) { - unsigned PredNo = MatcherTable[MatcherIndex++]; +CheckPatternPredicate(unsigned Opcode, const unsigned char *MatcherTable, + unsigned &MatcherIndex, const SelectionDAGISel &SDISel) { + bool TwoBytePredNo = + Opcode == SelectionDAGISel::OPC_CheckPatternPredicateTwoByte; + unsigned PredNo = + TwoBytePredNo || Opcode == SelectionDAGISel::OPC_CheckPatternPredicate + ? MatcherTable[MatcherIndex++] + : Opcode - SelectionDAGISel::OPC_CheckPatternPredicate0; if (TwoBytePredNo) PredNo |= MatcherTable[MatcherIndex++] << 8; return SDISel.CheckPatternPredicate(PredNo); @@ -2851,10 +2856,16 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table, Table[Index-1] - SelectionDAGISel::OPC_CheckChild0Same); return Index; case SelectionDAGISel::OPC_CheckPatternPredicate: + case SelectionDAGISel::OPC_CheckPatternPredicate0: + case SelectionDAGISel::OPC_CheckPatternPredicate1: case SelectionDAGISel::OPC_CheckPatternPredicate2: - Result = !::CheckPatternPredicate( - Table, Index, SDISel, - Table[Index - 1] == SelectionDAGISel::OPC_CheckPatternPredicate2); + case SelectionDAGISel::OPC_CheckPatternPredicate3: + case SelectionDAGISel::OPC_CheckPatternPredicate4: + case SelectionDAGISel::OPC_CheckPatternPredicate5: + case SelectionDAGISel::OPC_CheckPatternPredicate6: + case SelectionDAGISel::OPC_CheckPatternPredicate7: + case SelectionDAGISel::OPC_CheckPatternPredicateTwoByte: + Result = !::CheckPatternPredicate(Opcode, Table, Index, SDISel); return Index; case SelectionDAGISel::OPC_CheckPredicate: Result = !::CheckNodePredicate(Table, Index, SDISel, N.getNode()); @@ -3336,9 +3347,16 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, continue; case OPC_CheckPatternPredicate: + case OPC_CheckPatternPredicate0: + case OPC_CheckPatternPredicate1: case OPC_CheckPatternPredicate2: - if (!::CheckPatternPredicate(MatcherTable, MatcherIndex, *this, - Opcode == OPC_CheckPatternPredicate2)) + case OPC_CheckPatternPredicate3: + case OPC_CheckPatternPredicate4: + case OPC_CheckPatternPredicate5: + case OPC_CheckPatternPredicate6: + case OPC_CheckPatternPredicate7: + case OPC_CheckPatternPredicateTwoByte: + if (!::CheckPatternPredicate(Opcode, MatcherTable, MatcherIndex, *this)) break; continue; case OPC_CheckPredicate: diff --git a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp index e460a2804c6649..a3e2facf948e89 100644 --- a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp +++ b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp @@ -60,7 +60,6 @@ class MatcherTableEmitter { // all the patterns with "identical" predicates. StringMap> NodePredicatesByCodeToRun; - StringMap PatternPredicateMap; std::vector PatternPredicates; std::vector ComplexPatterns; @@ -87,6 +86,8 @@ class MatcherTableEmitter { : CGP(cgp), OpcodeCounts(Matcher::HighestKind + 1, 0) { // Record the usage of ComplexPattern. DenseMap ComplexPatternUsage; + // Record the usage of PatternPredicate. + std::map PatternPredicateUsage; // Iterate the whole MatcherTable once and do some statistics. std::function Statistic = [&](const Matcher *N) { @@ -102,6 +103,8 @@ class MatcherTableEmitter { Statistic(STM->getCaseMatcher(I)); else if (auto *CPM = dyn_cast(N)) ++ComplexPatternUsage[&CPM->getPattern()]; + else if (auto *CPPM = dyn_cast(N)) + ++PatternPredicateUsage[CPPM->getPredicate()]; N = N->getNext(); } }; @@ -114,6 +117,14 @@ class MatcherTableEmitter { [](const auto &A, const auto &B) { return A.second > B.second; }); for (const auto &ComplexPattern : ComplexPatternList) ComplexPatterns.push_back(ComplexPattern.first); + + // Sort PatternPredicates by usage. + std::vector> PatternPredicateList( + PatternPredicateUsage.begin(), PatternPredicateUsage.end()); + sort(PatternPredicateList, + [](const auto &A, const auto &B) { return A.second > B.second; }); + for (const auto &PatternPredicate : PatternPredicateList) + PatternPredicates.push_back(PatternPredicate.first); } unsigned EmitMatcherList(const Matcher *N, const unsigned Indent, @@ -167,12 +178,7 @@ class MatcherTableEmitter { } unsigned getPatternPredicate(StringRef PredName) { - unsigned &Entry = PatternPredicateMap[PredName]; - if (Entry == 0) { - PatternPredicates.push_back(PredName.str()); - Entry = PatternPredicates.size(); - } - return Entry-1; + return llvm::find(PatternPredicates, PredName) - PatternPredicates.begin(); } unsigned getComplexPat(const ComplexPattern &P) { return llvm::find(ComplexPatterns, &P) - ComplexPatterns.begin(); @@ -510,13 +516,15 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx, StringRef Pred = cast(N)->getPredicate(); unsigned PredNo = getPatternPredicate(Pred); if (PredNo > 255) - OS << "OPC_CheckPatternPredicate2, TARGET_VAL(" << PredNo << "),"; + OS << "OPC_CheckPatternPredicateTwoByte, TARGET_VAL(" << PredNo << "),"; + else if (PredNo < 8) + OS << "OPC_CheckPatternPredicate" << PredNo << ','; else OS << "OPC_CheckPatternPredicate, " << PredNo << ','; if (!OmitComments) OS << " // " << Pred; OS << '\n'; - return 2 + (PredNo > 255); + return 2 + (PredNo > 255) - (PredNo < 8); } case Matcher::CheckPredicate: { TreePredicateFn Pred = cast(N)->getPredicate();