diff --git a/src/coreclr/jit/inductionvariableopts.cpp b/src/coreclr/jit/inductionvariableopts.cpp index fa3d627feadcc8..990100ec8e3bbc 100644 --- a/src/coreclr/jit/inductionvariableopts.cpp +++ b/src/coreclr/jit/inductionvariableopts.cpp @@ -1341,7 +1341,11 @@ class StrengthReductionContext bool CheckAdvancedCursors(ArrayStack* cursors, ScevAddRec** nextIV); bool StaysWithinManagedObject(ArrayStack* cursors, ScevAddRec* addRec); bool TryReplaceUsesWithNewPrimaryIV(ArrayStack* cursors, ScevAddRec* iv); - BasicBlock* FindUpdateInsertionPoint(ArrayStack* cursors); + BasicBlock* FindUpdateInsertionPoint(ArrayStack* cursors, Statement** afterStmt); + BasicBlock* FindPostUseUpdateInsertionPoint(ArrayStack* cursors, + BasicBlock* backEdgeDominator, + Statement** afterStmt); + bool InsertionPointPostDominatesUses(BasicBlock* insertionPoint, ArrayStack* cursors); bool StressProfitability() { @@ -2000,7 +2004,8 @@ bool StrengthReductionContext::TryReplaceUsesWithNewPrimaryIV(ArrayStackgtNewOperNode(GT_ADD, iv->Type, m_comp->gtNewLclVarNode(newPrimaryIV, iv->Type), stepValue); GenTree* stepStore = m_comp->gtNewTempStore(newPrimaryIV, nextValue); Statement* stepStmt = m_comp->fgNewStmtFromTree(stepStore); - m_comp->fgInsertStmtNearEnd(insertionPoint, stepStmt); + if (afterStmt != nullptr) + { + m_comp->fgInsertStmtAfter(insertionPoint, afterStmt, stepStmt); + } + else + { + m_comp->fgInsertStmtNearEnd(insertionPoint, stepStmt); + } JITDUMP(" Inserting step statement in " FMT_BB "\n", insertionPoint->bbNum); DISPSTMT(stepStmt); @@ -2084,22 +2096,27 @@ bool StrengthReductionContext::TryReplaceUsesWithNewPrimaryIV(ArrayStack* cursors) +BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStack* cursors, Statement** afterStmt) { + *afterStmt = nullptr; + // Find insertion point. It needs to post-dominate all uses we are going to // replace and it needs to dominate all backedges. // TODO-CQ: Canonicalizing backedges would make this simpler and work in // more cases. BasicBlock* insertionPoint = nullptr; + for (FlowEdge* backEdge : m_loop->BackEdges()) { if (insertionPoint == nullptr) @@ -2112,6 +2129,18 @@ BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStackbbNum, (*afterStmt)->GetID()); + return postUseInsertionPoint; + } +#endif + while ((insertionPoint != nullptr) && m_loop->ContainsBlock(insertionPoint) && m_loop->MayExecuteBlockMultipleTimesPerIteration(insertionPoint)) { @@ -2123,6 +2152,124 @@ BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStackbbNum); + return insertionPoint; +} + +//------------------------------------------------------------------------ +// FindPostUseUpdateInsertionPoint: Try finding an insertion point for the IV +// update that is right after one of the uses of it. +// +// Parameters: +// cursors - The list of cursors pointing to uses that are being replaced by +// the new IV +// backEdgeDominator - A basic block that dominates all backedges +// afterStmt - [out] Statement to insert the update after, if the +// return value is non-null. +// +// Returns: +// nullptr if no such insertion point could be found. Otherwise returns the +// basic block and statement after which the update can be inserted. +// +BasicBlock* StrengthReductionContext::FindPostUseUpdateInsertionPoint(ArrayStack* cursors, + BasicBlock* backEdgeDominator, + Statement** afterStmt) +{ + BitVecTraits poTraits = m_loop->GetDfsTree()->PostOrderTraits(); + +#ifdef DEBUG + // We will be relying on the fact that the cursors are ordered in a useful + // way here: loop locals are visited in post order within each basic block, + // meaning that "cursors" has the last uses first for each basic block. + // Assert that here. + + BitVec seenBlocks(BitVecOps::MakeEmpty(&poTraits)); + for (int i = 1; i < cursors->Height(); i++) + { + CursorInfo& prevCursor = cursors->BottomRef(i - 1); + CursorInfo& cursor = cursors->BottomRef(i); + + if (cursor.Block != prevCursor.Block) + { + assert(BitVecOps::TryAddElemD(&poTraits, seenBlocks, prevCursor.Block->bbPostorderNum)); + continue; + } + + Statement* curStmt = cursor.Stmt; + while ((curStmt != nullptr) && (curStmt != prevCursor.Stmt)) + { + curStmt = curStmt->GetNextStmt(); + } + + assert(curStmt == prevCursor.Stmt); + } +#endif + + BitVec blocksWithUses(BitVecOps::MakeEmpty(&poTraits)); + for (int i = 0; i < cursors->Height(); i++) + { + CursorInfo& cursor = cursors->BottomRef(i); + BitVecOps::AddElemD(&poTraits, blocksWithUses, cursor.Block->bbPostorderNum); + } + + while ((backEdgeDominator != nullptr) && m_loop->ContainsBlock(backEdgeDominator)) + { + if (!BitVecOps::IsMember(&poTraits, blocksWithUses, backEdgeDominator->bbPostorderNum)) + { + backEdgeDominator = backEdgeDominator->bbIDom; + continue; + } + + if (m_loop->MayExecuteBlockMultipleTimesPerIteration(backEdgeDominator)) + { + return nullptr; + } + + for (int i = 0; i < cursors->Height(); i++) + { + CursorInfo& cursor = cursors->BottomRef(i); + if (cursor.Block != backEdgeDominator) + { + continue; + } + + if (!InsertionPointPostDominatesUses(cursor.Block, cursors)) + { + return nullptr; + } + + *afterStmt = cursor.Stmt; + return cursor.Block; + } + } + + return nullptr; +} + +//------------------------------------------------------------------------ +// InsertionPointPostDominatesUses: Check if a basic block post-dominates all +// locations specified by the cursors. +// +// Parameters: +// insertionPoint - The insertion point +// cursors - Cursors specifying locations +// +// Returns: +// True if so. +// +// Remarks: +// For cursors inside "insertionPoint", the function expects that the +// insertion point is _after_ the use, except if the use is in a terminator +// statement. +// +bool StrengthReductionContext::InsertionPointPostDominatesUses(BasicBlock* insertionPoint, + ArrayStack* cursors) +{ for (int i = 0; i < cursors->Height(); i++) { CursorInfo& cursor = cursors->BottomRef(i); @@ -2131,19 +2278,19 @@ BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStackHasTerminator() && (cursor.Stmt == insertionPoint->lastStmt())) { - return nullptr; + return false; } } else { if (!m_loop->IsPostDominatedOnLoopIteration(cursor.Block, insertionPoint)) { - return nullptr; + return false; } } } - return insertionPoint; + return true; } //------------------------------------------------------------------------