Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JIT: Have lowering set up IR for post-indexed addressing and make strength reduced IV updates amenable to post-indexed addressing #105185

Merged
merged 2 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 156 additions & 9 deletions src/coreclr/jit/inductionvariableopts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,11 @@ class StrengthReductionContext
bool CheckAdvancedCursors(ArrayStack<CursorInfo>* cursors, ScevAddRec** nextIV);
bool StaysWithinManagedObject(ArrayStack<CursorInfo>* cursors, ScevAddRec* addRec);
bool TryReplaceUsesWithNewPrimaryIV(ArrayStack<CursorInfo>* cursors, ScevAddRec* iv);
BasicBlock* FindUpdateInsertionPoint(ArrayStack<CursorInfo>* cursors);
BasicBlock* FindUpdateInsertionPoint(ArrayStack<CursorInfo>* cursors, Statement** afterStmt);
BasicBlock* FindPostUseUpdateInsertionPoint(ArrayStack<CursorInfo>* cursors,
BasicBlock* backEdgeDominator,
Statement** afterStmt);
bool InsertionPointPostDominatesUses(BasicBlock* insertionPoint, ArrayStack<CursorInfo>* cursors);

bool StressProfitability()
{
Expand Down Expand Up @@ -2000,7 +2004,8 @@ bool StrengthReductionContext::TryReplaceUsesWithNewPrimaryIV(ArrayStack<CursorI
return false;
}

BasicBlock* insertionPoint = FindUpdateInsertionPoint(cursors);
Statement* afterStmt;
BasicBlock* insertionPoint = FindUpdateInsertionPoint(cursors, &afterStmt);
if (insertionPoint == nullptr)
{
JITDUMP(" Skipping: could not find a legal insertion point for the new IV update\n");
Expand Down Expand Up @@ -2032,7 +2037,14 @@ bool StrengthReductionContext::TryReplaceUsesWithNewPrimaryIV(ArrayStack<CursorI
m_comp->gtNewOperNode(GT_ADD, iv->Type, m_comp->gtNewLclVarNode(newPrimaryIV, iv->Type), stepValue);
GenTree* stepStore = m_comp->gtNewTempStore(newPrimaryIV, nextValue);
Statement* stepStmt = m_comp->fgNewStmtFromTree(stepStore);
m_comp->fgInsertStmtNearEnd(insertionPoint, stepStmt);
if (afterStmt != nullptr)
{
m_comp->fgInsertStmtAfter(insertionPoint, afterStmt, stepStmt);
}
else
{
m_comp->fgInsertStmtNearEnd(insertionPoint, stepStmt);
}

JITDUMP(" Inserting step statement in " FMT_BB "\n", insertionPoint->bbNum);
DISPSTMT(stepStmt);
Expand Down Expand Up @@ -2084,22 +2096,27 @@ bool StrengthReductionContext::TryReplaceUsesWithNewPrimaryIV(ArrayStack<CursorI
// of a new primary IV introduced by strength reduction.
//
// Parameters:
// cursors - The list of cursors pointing to uses that are being replaced by
// the new IV
// cursors - The list of cursors pointing to uses that are being replaced by
// the new IV
// afterStmt - [out] Statement to insert the update after. Set to nullptr if
// update should be inserted near the end of the block.
//
// Returns:
// Basic block; the insertion point is the end (before a potential
// terminator) of this basic block. May return null if no insertion point
// could be found.
//
BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStack<CursorInfo>* cursors)
BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStack<CursorInfo>* cursors, Statement** afterStmt)
{
*afterStmt = nullptr;

// Find insertion point. It needs to post-dominate all uses we are going to
// replace and it needs to dominate all backedges.
// TODO-CQ: Canonicalizing backedges would make this simpler and work in
// more cases.

BasicBlock* insertionPoint = nullptr;

for (FlowEdge* backEdge : m_loop->BackEdges())
{
if (insertionPoint == nullptr)
Expand All @@ -2112,6 +2129,18 @@ BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStack<Cursor
}
}

#ifdef TARGET_ARM64
// For arm64 we try to place the IV update after a use if possible. This
// sets the backend up for post-indexed addressing mode.
BasicBlock* postUseInsertionPoint = FindPostUseUpdateInsertionPoint(cursors, insertionPoint, afterStmt);
if (postUseInsertionPoint != nullptr)
{
JITDUMP(" Found a legal insertion point after a last use of the IV in " FMT_BB " after " FMT_STMT "\n",
postUseInsertionPoint->bbNum, (*afterStmt)->GetID());
return postUseInsertionPoint;
}
#endif

while ((insertionPoint != nullptr) && m_loop->ContainsBlock(insertionPoint) &&
m_loop->MayExecuteBlockMultipleTimesPerIteration(insertionPoint))
{
Expand All @@ -2123,6 +2152,124 @@ BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStack<Cursor
return nullptr;
}

if (!InsertionPointPostDominatesUses(insertionPoint, cursors))
{
return nullptr;
}

JITDUMP(" Found a legal insertion point in " FMT_BB "\n", insertionPoint->bbNum);
return insertionPoint;
}

//------------------------------------------------------------------------
// FindPostUseUpdateInsertionPoint: Try finding an insertion point for the IV
// update that is right after one of the uses of it.
//
// Parameters:
// cursors - The list of cursors pointing to uses that are being replaced by
// the new IV
// backEdgeDominator - A basic block that dominates all backedges
// afterStmt - [out] Statement to insert the update after, if the
// return value is non-null.
//
// Returns:
// nullptr if no such insertion point could be found. Otherwise returns the
// basic block and statement after which the update can be inserted.
//
BasicBlock* StrengthReductionContext::FindPostUseUpdateInsertionPoint(ArrayStack<CursorInfo>* cursors,
BasicBlock* backEdgeDominator,
Statement** afterStmt)
{
BitVecTraits poTraits = m_loop->GetDfsTree()->PostOrderTraits();

#ifdef DEBUG
// We will be relying on the fact that the cursors are ordered in a useful
// way here: loop locals are visited in post order within each basic block,
// meaning that "cursors" has the last uses first for each basic block.
// Assert that here.

BitVec seenBlocks(BitVecOps::MakeEmpty(&poTraits));
for (int i = 1; i < cursors->Height(); i++)
{
CursorInfo& prevCursor = cursors->BottomRef(i - 1);
CursorInfo& cursor = cursors->BottomRef(i);

if (cursor.Block != prevCursor.Block)
{
assert(BitVecOps::TryAddElemD(&poTraits, seenBlocks, prevCursor.Block->bbPostorderNum));
continue;
}

Statement* curStmt = cursor.Stmt;
while ((curStmt != nullptr) && (curStmt != prevCursor.Stmt))
{
curStmt = curStmt->GetNextStmt();
}

assert(curStmt == prevCursor.Stmt);
}
#endif

BitVec blocksWithUses(BitVecOps::MakeEmpty(&poTraits));
for (int i = 0; i < cursors->Height(); i++)
{
CursorInfo& cursor = cursors->BottomRef(i);
BitVecOps::AddElemD(&poTraits, blocksWithUses, cursor.Block->bbPostorderNum);
}

while ((backEdgeDominator != nullptr) && m_loop->ContainsBlock(backEdgeDominator))
{
if (!BitVecOps::IsMember(&poTraits, blocksWithUses, backEdgeDominator->bbPostorderNum))
{
backEdgeDominator = backEdgeDominator->bbIDom;
continue;
}

if (m_loop->MayExecuteBlockMultipleTimesPerIteration(backEdgeDominator))
{
return nullptr;
}

for (int i = 0; i < cursors->Height(); i++)
{
CursorInfo& cursor = cursors->BottomRef(i);
if (cursor.Block != backEdgeDominator)
{
continue;
}

if (!InsertionPointPostDominatesUses(cursor.Block, cursors))
{
return nullptr;
}

*afterStmt = cursor.Stmt;
return cursor.Block;
}
}

return nullptr;
}

//------------------------------------------------------------------------
// InsertionPointPostDominatesUses: Check if a basic block post-dominates all
// locations specified by the cursors.
//
// Parameters:
// insertionPoint - The insertion point
// cursors - Cursors specifying locations
//
// Returns:
// True if so.
//
// Remarks:
// For cursors inside "insertionPoint", the function expects that the
// insertion point is _after_ the use, except if the use is in a terminator
// statement.
//
bool StrengthReductionContext::InsertionPointPostDominatesUses(BasicBlock* insertionPoint,
ArrayStack<CursorInfo>* cursors)
{
for (int i = 0; i < cursors->Height(); i++)
{
CursorInfo& cursor = cursors->BottomRef(i);
Expand All @@ -2131,19 +2278,19 @@ BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStack<Cursor
{
if (insertionPoint->HasTerminator() && (cursor.Stmt == insertionPoint->lastStmt()))
{
return nullptr;
return false;
}
}
else
{
if (!m_loop->IsPostDominatedOnLoopIteration(cursor.Block, insertionPoint))
{
return nullptr;
return false;
}
}
}

return insertionPoint;
return true;
}

//------------------------------------------------------------------------
Expand Down
17 changes: 10 additions & 7 deletions src/coreclr/jit/lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,7 @@ GenTree* Lowering::LowerNode(GenTree* node)
FALLTHROUGH;

case GT_STORE_LCL_FLD:
LowerStoreLocCommon(node->AsLclVarCommon());
break;
return LowerStoreLocCommon(node->AsLclVarCommon());

#if defined(TARGET_ARM64) || defined(TARGET_LOONGARCH64) || defined(TARGET_RISCV64)
case GT_CMPXCHG:
Expand Down Expand Up @@ -4783,7 +4782,10 @@ void Lowering::LowerRet(GenTreeOp* ret)
// Arguments:
// lclStore - The store lcl node to lower.
//
void Lowering::LowerStoreLocCommon(GenTreeLclVarCommon* lclStore)
// Returns:
// Next node to lower.
//
GenTree* Lowering::LowerStoreLocCommon(GenTreeLclVarCommon* lclStore)
{
assert(lclStore->OperIs(GT_STORE_LCL_FLD, GT_STORE_LCL_VAR));
JITDUMP("lowering store lcl var/field (before):\n");
Expand Down Expand Up @@ -4870,8 +4872,7 @@ void Lowering::LowerStoreLocCommon(GenTreeLclVarCommon* lclStore)
lclStore->gtOp1 = spilledCall;
src = lclStore->gtOp1;
JITDUMP("lowering store lcl var/field has to spill call src.\n");
LowerStoreLocCommon(lclStore);
return;
return LowerStoreLocCommon(lclStore);
}
#endif // !WINDOWS_AMD64_ABI
convertToStoreObj = false;
Expand Down Expand Up @@ -4966,7 +4967,7 @@ void Lowering::LowerStoreLocCommon(GenTreeLclVarCommon* lclStore)
DISPTREERANGE(BlockRange(), objStore);
JITDUMP("\n");

return;
return objStore->gtNext;
}
}

Expand All @@ -4984,11 +4985,13 @@ void Lowering::LowerStoreLocCommon(GenTreeLclVarCommon* lclStore)
ContainCheckBitCast(bitcast);
}

LowerStoreLoc(lclStore);
GenTree* next = LowerStoreLoc(lclStore);

JITDUMP("lowering store lcl var/field (after):\n");
DISPTREERANGE(BlockRange(), lclStore);
JITDUMP("\n");

return next;
}

//----------------------------------------------------------------------------------------------
Expand Down
14 changes: 8 additions & 6 deletions src/coreclr/jit/lower.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class Lowering final : public Phase
GenTreeCC* LowerNodeCC(GenTree* node, GenCondition condition);
void LowerJmpMethod(GenTree* jmp);
void LowerRet(GenTreeOp* ret);
void LowerStoreLocCommon(GenTreeLclVarCommon* lclVar);
GenTree* LowerStoreLocCommon(GenTreeLclVarCommon* lclVar);
void LowerRetStruct(GenTreeUnOp* ret);
void LowerRetSingleRegStructLclVar(GenTreeUnOp* ret);
void LowerCallStruct(GenTreeCall* call);
Expand Down Expand Up @@ -353,6 +353,8 @@ class Lowering final : public Phase
GenTree* LowerIndir(GenTreeIndir* ind);
bool OptimizeForLdpStp(GenTreeIndir* ind);
bool TryMakeIndirsAdjacent(GenTreeIndir* prevIndir, GenTreeIndir* indir);
bool TryMoveAddSubRMWAfterIndir(GenTreeLclVarCommon* store);
bool TryMakeIndirAndStoreAdjacent(GenTreeIndir* prevIndir, GenTreeLclVarCommon* store);
void MarkTree(GenTree* root);
void UnmarkTree(GenTree* root);
GenTree* LowerStoreIndir(GenTreeStoreInd* node);
Expand Down Expand Up @@ -401,11 +403,11 @@ class Lowering final : public Phase
bool LowerRMWMemOp(GenTreeIndir* storeInd);
#endif

void WidenSIMD12IfNecessary(GenTreeLclVarCommon* node);
bool CheckMultiRegLclVar(GenTreeLclVar* lclNode, int registerCount);
void LowerStoreLoc(GenTreeLclVarCommon* tree);
void LowerRotate(GenTree* tree);
void LowerShift(GenTreeOp* shift);
void WidenSIMD12IfNecessary(GenTreeLclVarCommon* node);
bool CheckMultiRegLclVar(GenTreeLclVar* lclNode, int registerCount);
GenTree* LowerStoreLoc(GenTreeLclVarCommon* tree);
void LowerRotate(GenTree* tree);
void LowerShift(GenTreeOp* shift);
#ifdef FEATURE_HW_INTRINSICS
GenTree* LowerHWIntrinsic(GenTreeHWIntrinsic* node);
void LowerHWIntrinsicCC(GenTreeHWIntrinsic* node, NamedIntrinsic newIntrinsicId, GenCondition condition);
Expand Down
Loading
Loading