Skip to content

Commit

Permalink
[SelectionDAG] Add space-optimized forms of OPC_CheckPatternPredicate
Browse files Browse the repository at this point in the history
We record the usage of each `PatternPredicate` and sort them by
usage.

For the top 8 `PatternPredicate`s, we will emit a
`OPC_CheckPatternPredicateN` to save one byte.

The old `OPC_CheckPatternPredicate2` is renamed to
`OPC_CheckPatternPredicateTwoByte`.

Overall this reduces the llc binary size with all in-tree targets by
about 93K.
  • Loading branch information
wangpc-pp committed Jan 11, 2024
1 parent 211abe3 commit a9e617e
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 17 deletions.
8 changes: 8 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGISel.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,15 @@ class SelectionDAGISel : public MachineFunctionPass {
OPC_CheckChild2Same,
OPC_CheckChild3Same,
OPC_CheckPatternPredicate,
OPC_CheckPatternPredicate0,
OPC_CheckPatternPredicate1,
OPC_CheckPatternPredicate2,
OPC_CheckPatternPredicate3,
OPC_CheckPatternPredicate4,
OPC_CheckPatternPredicate5,
OPC_CheckPatternPredicate6,
OPC_CheckPatternPredicate7,
OPC_CheckPatternPredicateTwoByte,
OPC_CheckPredicate,
OPC_CheckPredicateWithOperands,
OPC_CheckOpcode,
Expand Down
34 changes: 26 additions & 8 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2697,9 +2697,14 @@ LLVM_ATTRIBUTE_ALWAYS_INLINE static bool CheckChildSame(

/// CheckPatternPredicate - Implements OP_CheckPatternPredicate.
LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
CheckPatternPredicate(const unsigned char *MatcherTable, unsigned &MatcherIndex,
const SelectionDAGISel &SDISel, bool TwoBytePredNo) {
unsigned PredNo = MatcherTable[MatcherIndex++];
CheckPatternPredicate(unsigned Opcode, const unsigned char *MatcherTable,
unsigned &MatcherIndex, const SelectionDAGISel &SDISel) {
bool TwoBytePredNo =
Opcode == SelectionDAGISel::OPC_CheckPatternPredicateTwoByte;
unsigned PredNo =
TwoBytePredNo || Opcode == SelectionDAGISel::OPC_CheckPatternPredicate
? MatcherTable[MatcherIndex++]
: Opcode - SelectionDAGISel::OPC_CheckPatternPredicate0;
if (TwoBytePredNo)
PredNo |= MatcherTable[MatcherIndex++] << 8;
return SDISel.CheckPatternPredicate(PredNo);
Expand Down Expand Up @@ -2851,10 +2856,16 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
Table[Index-1] - SelectionDAGISel::OPC_CheckChild0Same);
return Index;
case SelectionDAGISel::OPC_CheckPatternPredicate:
case SelectionDAGISel::OPC_CheckPatternPredicate0:
case SelectionDAGISel::OPC_CheckPatternPredicate1:
case SelectionDAGISel::OPC_CheckPatternPredicate2:
Result = !::CheckPatternPredicate(
Table, Index, SDISel,
Table[Index - 1] == SelectionDAGISel::OPC_CheckPatternPredicate2);
case SelectionDAGISel::OPC_CheckPatternPredicate3:
case SelectionDAGISel::OPC_CheckPatternPredicate4:
case SelectionDAGISel::OPC_CheckPatternPredicate5:
case SelectionDAGISel::OPC_CheckPatternPredicate6:
case SelectionDAGISel::OPC_CheckPatternPredicate7:
case SelectionDAGISel::OPC_CheckPatternPredicateTwoByte:
Result = !::CheckPatternPredicate(Opcode, Table, Index, SDISel);
return Index;
case SelectionDAGISel::OPC_CheckPredicate:
Result = !::CheckNodePredicate(Table, Index, SDISel, N.getNode());
Expand Down Expand Up @@ -3336,9 +3347,16 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
continue;

case OPC_CheckPatternPredicate:
case OPC_CheckPatternPredicate0:
case OPC_CheckPatternPredicate1:
case OPC_CheckPatternPredicate2:
if (!::CheckPatternPredicate(MatcherTable, MatcherIndex, *this,
Opcode == OPC_CheckPatternPredicate2))
case OPC_CheckPatternPredicate3:
case OPC_CheckPatternPredicate4:
case OPC_CheckPatternPredicate5:
case OPC_CheckPatternPredicate6:
case OPC_CheckPatternPredicate7:
case OPC_CheckPatternPredicateTwoByte:
if (!::CheckPatternPredicate(Opcode, MatcherTable, MatcherIndex, *this))
break;
continue;
case OPC_CheckPredicate:
Expand Down
26 changes: 17 additions & 9 deletions llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ class MatcherTableEmitter {
// all the patterns with "identical" predicates.
StringMap<TinyPtrVector<TreePattern *>> NodePredicatesByCodeToRun;

StringMap<unsigned> PatternPredicateMap;
std::vector<std::string> PatternPredicates;

std::vector<const ComplexPattern*> ComplexPatterns;
Expand All @@ -87,6 +86,8 @@ class MatcherTableEmitter {
: CGP(cgp), OpcodeCounts(Matcher::HighestKind + 1, 0) {
// Record the usage of ComplexPattern.
DenseMap<const ComplexPattern *, unsigned> ComplexPatternUsage;
// Record the usage of PatternPredicate.
std::map<StringRef, unsigned> PatternPredicateUsage;

// Iterate the whole MatcherTable once and do some statistics.
std::function<void(const Matcher *)> Statistic = [&](const Matcher *N) {
Expand All @@ -102,6 +103,8 @@ class MatcherTableEmitter {
Statistic(STM->getCaseMatcher(I));
else if (auto *CPM = dyn_cast<CheckComplexPatMatcher>(N))
++ComplexPatternUsage[&CPM->getPattern()];
else if (auto *CPPM = dyn_cast<CheckPatternPredicateMatcher>(N))
++PatternPredicateUsage[CPPM->getPredicate()];
N = N->getNext();
}
};
Expand All @@ -114,6 +117,14 @@ class MatcherTableEmitter {
[](const auto &A, const auto &B) { return A.second > B.second; });
for (const auto &ComplexPattern : ComplexPatternList)
ComplexPatterns.push_back(ComplexPattern.first);

// Sort PatternPredicates by usage.
std::vector<std::pair<std::string, unsigned>> PatternPredicateList(
PatternPredicateUsage.begin(), PatternPredicateUsage.end());
sort(PatternPredicateList,
[](const auto &A, const auto &B) { return A.second > B.second; });
for (const auto &PatternPredicate : PatternPredicateList)
PatternPredicates.push_back(PatternPredicate.first);
}

unsigned EmitMatcherList(const Matcher *N, const unsigned Indent,
Expand Down Expand Up @@ -167,12 +178,7 @@ class MatcherTableEmitter {
}

unsigned getPatternPredicate(StringRef PredName) {
unsigned &Entry = PatternPredicateMap[PredName];
if (Entry == 0) {
PatternPredicates.push_back(PredName.str());
Entry = PatternPredicates.size();
}
return Entry-1;
return llvm::find(PatternPredicates, PredName) - PatternPredicates.begin();
}
unsigned getComplexPat(const ComplexPattern &P) {
return llvm::find(ComplexPatterns, &P) - ComplexPatterns.begin();
Expand Down Expand Up @@ -510,13 +516,15 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
StringRef Pred = cast<CheckPatternPredicateMatcher>(N)->getPredicate();
unsigned PredNo = getPatternPredicate(Pred);
if (PredNo > 255)
OS << "OPC_CheckPatternPredicate2, TARGET_VAL(" << PredNo << "),";
OS << "OPC_CheckPatternPredicateTwoByte, TARGET_VAL(" << PredNo << "),";
else if (PredNo < 8)
OS << "OPC_CheckPatternPredicate" << PredNo << ',';
else
OS << "OPC_CheckPatternPredicate, " << PredNo << ',';
if (!OmitComments)
OS << " // " << Pred;
OS << '\n';
return 2 + (PredNo > 255);
return 2 + (PredNo > 255) - (PredNo < 8);
}
case Matcher::CheckPredicate: {
TreePredicateFn Pred = cast<CheckPredicateMatcher>(N)->getPredicate();
Expand Down

0 comments on commit a9e617e

Please sign in to comment.