Arm64: Combine if conditions into compare chains (#79283)
Add a new stage optOptimizeCompareChainCondBlock in pass optOptimizeBools.

This aims to reduce the number of conditional jumps by joining cases where multiple conditions gate the execution of a block.

Example 1:
If ( a > b || c == d) { x = y; }

Will be represented in IR as:

 ------------ BB01 -> BB03 (cond), succs={BB02,BB03}
 *  JTRUE (GT a,b)

 ------------ BB02 -> BB04 (cond), preds={BB01} succs={BB03,BB04}
 *  JTRUE (NE c,d)

 ------------ BB03, preds={BB01, BB02} succs={BB04}
 *  ASG (x,y)
These operands will be combined into a single AND in the first block (with the first
condition inverted), wrapped by the test condition (NE(...,0)). Giving:

 ------------ BB01 -> BB03 (cond), succs={BB03,BB04}
 *  JTRUE (NE (AND (LE a,b), (NE c,d)), 0)

 ------------ BB03, preds={BB01} succs={BB04}
 *  ASG x,y
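
The equivalence behind Example 1 can be sketched in plain C++. This is only an illustration, assuming int operands; the function names are hypothetical and are not part of the JIT. The chained form evaluates both compares unconditionally, joins them with a bitwise AND, and treats a non-zero result as "jump over the block".

```cpp
#include <cassert>

// Short-circuit form, as written in the source program.
bool takesBlockShortCircuit(int a, int b, int c, int d)
{
    return a > b || c == d;
}

// Chained form (hypothetical helper): the first condition is inverted
// (a > b becomes a <= b), both compares are always evaluated, and the
// results are joined with a bitwise AND. A non-zero result models the
// NE(..., 0) test that jumps over the block.
bool takesBlockChained(int a, int b, int c, int d)
{
    const int skipBlock = static_cast<int>(a <= b) & static_cast<int>(c != d);
    return skipBlock == 0;
}

// Compares the two forms over a small range of inputs.
void verifyExample1()
{
    for (int a = -2; a <= 2; a++)
        for (int b = -2; b <= 2; b++)
            for (int c = -2; c <= 2; c++)
                for (int d = -2; d <= 2; d++)
                    assert(takesBlockShortCircuit(a, b, c, d) == takesBlockChained(a, b, c, d));
}
```

Calling verifyExample1() checks that both forms agree for every combination in the tested range; the only observable difference is that the chained form never stops early.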
Example 2:
If ( a > b && c == d) { x = y; } else { x = z; }

Here the && conditions are connected via an OR. After the pass:

 ------------ BB01 -> BB03 (cond), succs={BB03,BB04}
 *  JTRUE (NE (OR (LE a,b), (NE c,d)), 0)

 ------------ BB03, preds={BB01} succs={BB05}
 *  ASG x,y

 ------------ BB04, preds={BB01} succs={BB05}
 *  ASG x,z
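
For the && case, a similar hedged sketch: by De Morgan, the OR of the two inverted conditions is true exactly when the original && is false, so it selects the else assignment. The helper names below are hypothetical, and int operands are assumed.

```cpp
#include <cassert>

// Short-circuit form, as written in the source program.
int selectShortCircuit(int a, int b, int c, int d, int y, int z)
{
    if (a > b && c == d)
    {
        return y;
    }
    return z;
}

// Chained form (hypothetical helper): both inverted compares are always
// evaluated and joined with a bitwise OR. A non-zero result means the
// original && failed, so the else value is chosen.
int selectChained(int a, int b, int c, int d, int y, int z)
{
    const int takeElse = static_cast<int>(a <= b) | static_cast<int>(c != d);
    return (takeElse != 0) ? z : y;
}
```

Both functions return the same value for any inputs; the chained form simply trades one conditional jump for an extra compare that is always executed.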
Example 3:
If ( a > b || c == d || e < f ) { x = y; }
The first pass of the optimization will combine two of the conditions. The
second pass will then combine the remaining condition with the earlier chain.

 ------------ BB01 -> BB03 (cond), succs={BB03,BB04}
 *  JTRUE (NE (OR ((NE (OR (NE c,d), (GE e,f)), 0), (LE a,b))), 0)

 ------------ BB03, preds={BB01} succs={BB04}
 *  ASG x,y
This optimization means that every condition within the IF statement is always evaluated,
as opposed to stopping at the first positive match.
Theoretically there is no maximum limit on the size of the generated chain. Therefore cost
checking is used to limit the maximum number of conditions that can be chained together.

Currently the cost check limits a chain to at most three simple conditions. This is the same behaviour as GCC. Note that LLVM allows much longer chains.
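
The cost gate can be sketched as a standalone function. The thresholds 7 and 31 are the ones used in the optimizer.cpp change in this commit; the function and constant names are hypothetical, and the costs are plain ints standing in for the JIT's execution-cost estimate.

```cpp
#include <cassert>

// Thresholds as they appear in this commit's diff: a single simple
// condition may cost at most 7, an already-built chain at most 31.
constexpr int kMaxSimpleConditionCost = 7;
constexpr int kMaxChainCost           = 31;

// Hypothetical helper: returns true when both operands are cheap enough
// to be evaluated unconditionally as part of one chain.
bool withinChainCostLimit(int op1Cost, bool op1IsChain, int op2Cost, bool op2IsChain)
{
    const int maxOp1Cost = op1IsChain ? kMaxChainCost : kMaxSimpleConditionCost;
    const int maxOp2Cost = op2IsChain ? kMaxChainCost : kMaxSimpleConditionCost;
    return (op1Cost <= maxOp1Cost) && (op2Cost <= maxOp2Cost);
}
```

An operand that is already a chain gets the larger budget, which is what lets a third condition join an existing two-condition chain while a fourth is rejected.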
a74nh authored Mar 21, 2023
1 parent 6439980 commit e431c00
Showing 6 changed files with 474 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/coreclr/jit/compiler.h
Expand Up @@ -9802,6 +9802,7 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
STRESS_MODE(MERGED_RETURNS) \
STRESS_MODE(BB_PROFILE) \
STRESS_MODE(OPT_BOOLS_GC) \
STRESS_MODE(OPT_BOOLS_COMPARE_CHAIN_COST) \
STRESS_MODE(REMORPH_TREES) \
STRESS_MODE(64RSLT_MUL) \
STRESS_MODE(DO_WHILE_LOOPS) \
278 changes: 277 additions & 1 deletion src/coreclr/jit/optimizer.cpp
@@ -9100,6 +9100,7 @@ class OptBoolsDsc

public:
bool optOptimizeBoolsCondBlock();
bool optOptimizeCompareChainCondBlock();
bool optOptimizeBoolsReturnBlock(BasicBlock* b3);
#ifdef DEBUG
void optOptimizeBoolsGcStress();
@@ -9110,6 +9111,7 @@ class OptBoolsDsc
GenTree* optIsBoolComp(OptTestInfo* pOptTest);
bool optOptimizeBoolsChkTypeCostCond();
void optOptimizeBoolsUpdateTrees();
bool FindCompareChain(GenTree* condition, bool* isTestCondition);
};

//-----------------------------------------------------------------------------
@@ -9335,6 +9337,267 @@ bool OptBoolsDsc::optOptimizeBoolsCondBlock()
return true;
}

//-----------------------------------------------------------------------------
// FindCompareChain: Check if the given condition is a compare chain.
//
// Arguments:
// condition: Condition to check.
// isTestCondition: Returns true if condition is a test condition but is not a compare chain.
//
// Returns:
// true if the condition is a compare chain.
//
// Assumptions:
// m_b1 and m_b2 are set on entry.
//

bool OptBoolsDsc::FindCompareChain(GenTree* condition, bool* isTestCondition)
{
GenTree* condOp1 = condition->gtGetOp1();
GenTree* condOp2 = condition->gtGetOp2();

*isTestCondition = false;

if (condition->OperIs(GT_EQ, GT_NE) && condOp2->IsIntegralConst())
{
ssize_t condOp2Value = condOp2->AsIntCon()->IconValue();

if (condOp2Value == 0)
{
// Found a EQ/NE(...,0). Does it contain a compare chain (ie - conditions that have
// previously been combined by optOptimizeCompareChainCondBlock) or is it a test condition
// that will be optimised to cbz/cbnz during lowering?

if (condOp1->OperIs(GT_AND, GT_OR))
{
// Check that the second operand of AND/OR ends with a compare operation, as this will be
// the condition the new link in the chain will connect with.
if (condOp1->gtGetOp2()->OperIsCmpCompare() && varTypeIsIntegralOrI(condOp1->gtGetOp2()->gtGetOp1()))
{
return true;
}
}

*isTestCondition = true;
}
else if (condOp1->OperIs(GT_AND) && isPow2(static_cast<target_size_t>(condOp2Value)) &&
condOp1->gtGetOp2()->IsIntegralConst(condOp2Value))
{
// Found a EQ/NE(AND(...,n),n) which will be optimized to tbz/tbnz during lowering.
*isTestCondition = true;
}
}

return false;
}

//-----------------------------------------------------------------------------
// optOptimizeCompareChainCondBlock: Create a chain when both m_b1 and m_b2 are BBJ_COND.
//
// Returns:
// true if chain optimization is done and m_b1 and m_b2 are folded into m_b1, else false.
//
// Assumptions:
// m_b1 and m_b2 are set on entry.
//
// Notes:
//
// This aims to reduce the number of conditional jumps by joining cases when multiple
// conditions gate the execution of a block.
//
// Example 1:
// If ( a > b || c == d) { x = y; }
//
// Will be represented in IR as:
//
// ------------ BB01 -> BB03 (cond), succs={BB02,BB03}
// * JTRUE (GT a,b)
//
// ------------ BB02 -> BB04 (cond), preds={BB01} succs={BB03,BB04}
// * JTRUE (NE c,d)
//
// ------------ BB03, preds={BB01, BB02} succs={BB04}
// * ASG (x,y)
//
// These operands will be combined into a single AND in the first block (with the first
// condition inverted), wrapped by the test condition (NE(...,0)). Giving:
//
// ------------ BB01 -> BB03 (cond), succs={BB03,BB04}
// * JTRUE (NE (AND (LE a,b), (NE c,d)), 0)
//
// ------------ BB03, preds={BB01} succs={BB04}
// * ASG x,y
//
//
// Example 2:
// If ( a > b && c == d) { x = y; } else { x = z; }
//
// Here the && conditions are connected via an OR. After the pass:
//
// ------------ BB01 -> BB03 (cond), succs={BB03,BB04}
// * JTRUE (NE (OR (LE a,b), (NE c,d)), 0)
//
// ------------ BB03, preds={BB01} succs={BB05}
// * ASG x,y
//
// ------------ BB04, preds={BB01} succs={BB05}
// * ASG x,z
//
//
// Example 3:
// If ( a > b || c == d || e < f ) { x = y; }
// The first pass of the optimization will combine two of the conditions. The
// second pass will then combine the remaining condition with the earlier chain.
//
// ------------ BB01 -> BB03 (cond), succs={BB03,BB04}
// * JTRUE (NE (OR ((NE (OR (NE c,d), (GE e,f)), 0), (LE a,b))), 0)
//
// ------------ BB03, preds={BB01} succs={BB04}
// * ASG x,y
//
//
// This optimization means that every condition within the IF statement is always evaluated,
// as opposed to stopping at the first positive match.
// Theoretically there is no maximum limit on the size of the generated chain. Therefore cost
// checking is used to limit the maximum number of conditions that can be chained together.
//
bool OptBoolsDsc::optOptimizeCompareChainCondBlock()
{
assert((m_b1 != nullptr) && (m_b2 != nullptr) && (m_b3 == nullptr));
m_t3 = nullptr;

bool foundEndOfOrConditions = false;
if ((m_b1->bbNext == m_b2) && (m_b1->bbJumpDest == m_b2->bbNext))
{
// Found the end of two (or more) conditions being ORed together.
// The final condition has been inverted.
foundEndOfOrConditions = true;
}
else if ((m_b1->bbNext == m_b2) && (m_b1->bbJumpDest == m_b2->bbJumpDest))
{
// Found two conditions connected together.
}
else
{
return false;
}

Statement* const s1 = optOptimizeBoolsChkBlkCond();
if (s1 == nullptr)
{
return false;
}
Statement* s2 = m_b2->firstStmt();

assert(m_testInfo1.testTree->OperIs(GT_JTRUE));
GenTree* cond1 = m_testInfo1.testTree->gtGetOp1();
assert(m_testInfo2.testTree->OperIs(GT_JTRUE));
GenTree* cond2 = m_testInfo2.testTree->gtGetOp1();

// Ensure both conditions are suitable.
if (!cond1->OperIsCompare() || !cond2->OperIsCompare())
{
return false;
}

// Ensure there are no additional side effects.
if ((cond1->gtFlags & (GTF_SIDE_EFFECT | GTF_ORDER_SIDEEFF)) != 0 ||
(cond2->gtFlags & (GTF_SIDE_EFFECT | GTF_ORDER_SIDEEFF)) != 0)
{
return false;
}

// Integer compares only for now (until support for Arm64 fccmp instruction is added)
if (varTypeIsFloating(cond1->gtGetOp1()) || varTypeIsFloating(cond2->gtGetOp1()))
{
return false;
}

// Check for previously optimized compare chains.
bool op1IsTestCond;
bool op2IsTestCond;
bool op1IsCondChain = FindCompareChain(cond1, &op1IsTestCond);
bool op2IsCondChain = FindCompareChain(cond2, &op2IsTestCond);

// Avoid cases where optimizations in lowering will produce better code than optimizing here.
if (op1IsTestCond || op2IsTestCond)
{
return false;
}

// Combining conditions means that all conditions are always fully evaluated.
// Put a limit on the max size that can be combined.
if (!m_comp->compStressCompile(Compiler::STRESS_OPT_BOOLS_COMPARE_CHAIN_COST, 25))
{
int op1Cost = cond1->GetCostEx();
int op2Cost = cond2->GetCostEx();
// The cost of combining three simple conditions is 32.
int maxOp1Cost = op1IsCondChain ? 31 : 7;
int maxOp2Cost = op2IsCondChain ? 31 : 7;

// Cost to allow for chain size of three.
if (op1Cost > maxOp1Cost || op2Cost > maxOp2Cost)
{
JITDUMP("Skipping CompareChainCond that will evaluate conditions unconditionally at costs %d,%d\n", op1Cost,
op2Cost);
return false;
}
}

// Remove the first JTRUE statement.
constexpr bool isUnlink = true;
m_comp->fgRemoveStmt(m_b1, s1 DEBUGARG(isUnlink));

// Invert the condition.
if (foundEndOfOrConditions)
{
GenTree* revCond = m_comp->gtReverseCond(cond1);
assert(cond1 == revCond); // Ensure `gtReverseCond` did not create a new node.
}

// Join the two conditions together
genTreeOps chainedOper = foundEndOfOrConditions ? GT_AND : GT_OR;
GenTree* chainedConditions = m_comp->gtNewOperNode(chainedOper, TYP_INT, cond1, cond2);
cond1->gtFlags &= ~GTF_RELOP_JMP_USED;
cond2->gtFlags &= ~GTF_RELOP_JMP_USED;
chainedConditions->gtFlags |= (GTF_RELOP_JMP_USED | GTF_DONT_CSE);

// Add a test condition onto the front of the chain
GenTree* testcondition =
m_comp->gtNewOperNode(GT_NE, TYP_INT, chainedConditions, m_comp->gtNewZeroConNode(TYP_INT));

// Wire the chain into the second block
m_testInfo2.testTree->AsOp()->gtOp1 = testcondition;
m_testInfo2.testTree->AsOp()->gtFlags |= (testcondition->gtFlags & GTF_ALL_EFFECT);
m_comp->gtSetEvalOrder(m_testInfo2.testTree);
m_comp->fgSetStmtSeq(s2);

// Update the flow.
m_comp->fgRemoveRefPred(m_b1->bbJumpDest, m_b1);
m_b1->bbJumpKind = BBJ_NONE;

// Fixup flags.
m_b2->bbFlags |= (m_b1->bbFlags & BBF_COPY_PROPAGATE);

// Join the two blocks. This is done now to ensure that additional conditions can be chained.
if (m_comp->fgCanCompactBlocks(m_b1, m_b2))
{
m_comp->fgCompactBlocks(m_b1, m_b2);
}

#ifdef DEBUG
if (m_comp->verbose)
{
JITDUMP("\nCombined conditions " FMT_BB " and " FMT_BB " into %s chain :\n", m_b1->bbNum, m_b2->bbNum,
GenTree::OpName(chainedOper));
m_comp->fgDumpBlock(m_b1);
JITDUMP("\n");
}
#endif

return true;
}

//-----------------------------------------------------------------------------
// optOptimizeBoolsChkBlkCond: Checks block conditions if it can be boolean optimized
//
@@ -10076,6 +10339,7 @@ PhaseStatus Compiler::optOptimizeBools()
}
#endif
bool change = false;
bool retry = false;
unsigned numCond = 0;
unsigned numReturn = 0;
unsigned numPasses = 0;
@@ -10086,8 +10350,10 @@ PhaseStatus Compiler::optOptimizeBools()
numPasses++;
change = false;

-for (BasicBlock* const b1 : Blocks())
+for (BasicBlock* b1 = fgFirstBB; b1 != nullptr; b1 = retry ? b1 : b1->bbNext)
{
retry = false;

// We're only interested in conditional jumps here

if (b1->bbJumpKind != BBJ_COND)
@@ -10127,6 +10393,16 @@ PhaseStatus Compiler::optOptimizeBools()
change = true;
numCond++;
}
#ifdef TARGET_ARM64
else if (optBoolsDsc.optOptimizeCompareChainCondBlock())
{
// The optimization will have merged b1 and b2. Retry the loop so that
// b1 and b2->bbNext can be tested.
change = true;
retry = true;
numCond++;
}
#endif
}
else if (b2->bbJumpKind == BBJ_RETURN)
{