Skip to content

Commit

Permalink
Handle more than 64 registers - Part 1 (#101950)
Browse files Browse the repository at this point in the history
* Convert regMaskTP for ARM64 to struct with single field

* Fix genFirstRegNumFromMaskAndToggle() and genFirstRegNumFromMask()

* minor fix

* review feedback

* fix the TP regression from 1.5% -> 0.5%

* Pass by value

* jit format

* review feedback

* Remove FORCEINLINE

* Remove setLow()
  • Loading branch information
kunalspathak authored May 11, 2024
1 parent 34e65b9 commit fdc9c9d
Show file tree
Hide file tree
Showing 11 changed files with 228 additions and 32 deletions.
9 changes: 4 additions & 5 deletions src/coreclr/jit/codegencommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3829,9 +3829,9 @@ void CodeGen::genZeroInitFltRegs(const regMaskTP& initFltRegs, const regMaskTP&

// Iterate through float/double registers and initialize them to 0 or
// copy from already initialized register of the same type.
regMaskTP regMask = genRegMask(REG_FP_FIRST);
for (regNumber reg = REG_FP_FIRST; reg <= REG_FP_LAST; reg = REG_NEXT(reg), regMask <<= 1)
for (regNumber reg = REG_FP_FIRST; reg <= REG_FP_LAST; reg = REG_NEXT(reg))
{
regMaskTP regMask = genRegMask(reg);
if (regMask & initFltRegs)
{
// Do we have a float register already set to 0?
Expand Down Expand Up @@ -5732,10 +5732,9 @@ void CodeGen::genFnProlog()

if (initRegs)
{
regMaskTP regMask = 0x1;

for (regNumber reg = REG_INT_FIRST; reg <= REG_INT_LAST; reg = REG_NEXT(reg), regMask <<= 1)
for (regNumber reg = REG_INT_FIRST; reg <= REG_INT_LAST; reg = REG_NEXT(reg))
{
regMaskTP regMask = genRegMask(reg);
if (regMask & initRegs)
{
// Check if we have already zeroed this register
Expand Down
64 changes: 61 additions & 3 deletions src/coreclr/jit/compiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,33 @@ inline bool genExactlyOneBit(T value)
return ((value != 0) && genMaxOneBit(value));
}

#ifdef TARGET_ARM64
inline regMaskTP genFindLowestBit(regMaskTP value)
{
return regMaskTP(genFindLowestBit(value.getLow()));
}

/*****************************************************************************
*
* Return true if the given value has exactly zero or one bits set.
*/

inline bool genMaxOneBit(regMaskTP value)
{
return genMaxOneBit(value.getLow());
}

/*****************************************************************************
*
* Return true if the given value has exactly one bit set.
*/

inline bool genExactlyOneBit(regMaskTP value)
{
return genExactlyOneBit(value.getLow());
}
#endif

/*****************************************************************************
*
* Given a value that has exactly one bit set, return the position of that
Expand Down Expand Up @@ -147,6 +174,13 @@ inline unsigned genCountBits(uint64_t bits)
return BitOperations::PopCount(bits);
}

#ifdef TARGET_ARM64
inline unsigned genCountBits(regMaskTP mask)
{
return BitOperations::PopCount(mask.getLow());
}
#endif

/*****************************************************************************
*
* A rather simple routine that counts the number of bits in a given number.
Expand Down Expand Up @@ -914,11 +948,18 @@ inline regNumber genRegNumFromMask(regMaskTP mask)

/* Convert the mask to a register number */

regNumber regNum = (regNumber)genLog2(mask);
#ifdef TARGET_ARM64
regNumber regNum = (regNumber)genLog2(mask.getLow());

/* Make sure we got it right */
assert(genRegMask(regNum) == mask.getLow());

#else
regNumber regNum = (regNumber)genLog2(mask);

/* Make sure we got it right */
assert(genRegMask(regNum) == mask);
#endif

return regNum;
}
Expand All @@ -940,7 +981,8 @@ inline regNumber genFirstRegNumFromMaskAndToggle(regMaskTP& mask)

/* Convert the mask to a register number */

regNumber regNum = (regNumber)BitOperations::BitScanForward(mask);
regNumber regNum = (regNumber)BitScanForward(mask);

mask ^= genRegMask(regNum);

return regNum;
Expand All @@ -962,7 +1004,7 @@ inline regNumber genFirstRegNumFromMask(regMaskTP mask)

/* Convert the mask to a register number */

regNumber regNum = (regNumber)BitOperations::BitScanForward(mask);
regNumber regNum = (regNumber)BitScanForward(mask);

return regNum;
}
Expand Down Expand Up @@ -4463,30 +4505,46 @@ inline void* operator new[](size_t sz, Compiler* compiler, CompMemKind cmk)

inline void printRegMask(regMaskTP mask)
{
#ifdef TARGET_ARM64
printf(REG_MASK_ALL_FMT, mask.getLow());
#else
printf(REG_MASK_ALL_FMT, mask);
#endif
}

inline char* regMaskToString(regMaskTP mask, Compiler* context)
{
const size_t cchRegMask = 24;
char* regmask = new (context, CMK_Unknown) char[cchRegMask];

#ifdef TARGET_ARM64
sprintf_s(regmask, cchRegMask, REG_MASK_ALL_FMT, mask.getLow());
#else
sprintf_s(regmask, cchRegMask, REG_MASK_ALL_FMT, mask);
#endif

return regmask;
}

inline void printRegMaskInt(regMaskTP mask)
{
#ifdef TARGET_ARM64
printf(REG_MASK_INT_FMT, (mask & RBM_ALLINT).getLow());
#else
printf(REG_MASK_INT_FMT, (mask & RBM_ALLINT));
#endif
}

inline char* regMaskIntToString(regMaskTP mask, Compiler* context)
{
const size_t cchRegMask = 24;
char* regmask = new (context, CMK_Unknown) char[cchRegMask];

#ifdef TARGET_ARM64
sprintf_s(regmask, cchRegMask, REG_MASK_INT_FMT, (mask & RBM_ALLINT).getLow());
#else
sprintf_s(regmask, cchRegMask, REG_MASK_INT_FMT, (mask & RBM_ALLINT));
#endif

return regmask;
}
Expand Down
8 changes: 4 additions & 4 deletions src/coreclr/jit/emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3510,7 +3510,7 @@ void emitter::emitDispRegSet(regMaskTP regs)
continue;
}

regs -= curReg;
regs ^= curReg;

if (sp)
{
Expand Down Expand Up @@ -3870,8 +3870,8 @@ void emitter::emitDispGCRegDelta(const char* title, regMaskTP prevRegs, regMaskT
{
emitDispGCDeltaTitle(title);
regMaskTP sameRegs = prevRegs & curRegs;
regMaskTP removedRegs = prevRegs - sameRegs;
regMaskTP addedRegs = curRegs - sameRegs;
regMaskTP removedRegs = prevRegs ^ sameRegs;
regMaskTP addedRegs = curRegs ^ sameRegs;
if (removedRegs != RBM_NONE)
{
printf(" -");
Expand Down Expand Up @@ -8972,7 +8972,7 @@ void emitter::emitUpdateLiveGCregs(GCtype gcType, regMaskTP regs, BYTE* addr)
emitGCregDeadUpd(reg, addr);
}

chg -= bit;
chg ^= bit;
} while (chg);

assert(emitThisXXrefRegs == regs);
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/gcencode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4667,7 +4667,7 @@ void GCInfo::gcInfoRecordGCRegStateChange(GcInfoEncoder* gcInfoEncoder,
}

// Turn the bit we've just generated off and continue.
regMask -= tmpMask; // EAX,ECX,EDX,EBX,---,EBP,ESI,EDI
regMask ^= tmpMask; // EAX,ECX,EDX,EBX,---,EBP,ESI,EDI
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/lsra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13614,7 +13614,7 @@ regMaskTP LinearScan::RegisterSelection::select(Interval* current
&overallLimitCandidates);
assert(limitConsecutiveResult != RBM_NONE);

unsigned startRegister = BitOperations::BitScanForward(limitConsecutiveResult);
unsigned startRegister = BitScanForward(limitConsecutiveResult);

regMaskTP registersNeededMask = (1ULL << refPosition->regCount) - 1;
candidates |= (registersNeededMask << startRegister);
Expand Down
4 changes: 2 additions & 2 deletions src/coreclr/jit/lsra.h
Original file line number Diff line number Diff line change
Expand Up @@ -794,8 +794,8 @@ class LinearScan : public LinearScanInterface
static const regMaskTP LsraLimitSmallIntSet = (RBM_R0 | RBM_R1 | RBM_R2 | RBM_R3 | RBM_R4 | RBM_R5);
static const regMaskTP LsraLimitSmallFPSet = (RBM_F0 | RBM_F1 | RBM_F2 | RBM_F16 | RBM_F17);
#elif defined(TARGET_ARM64)
static const regMaskTP LsraLimitSmallIntSet = (RBM_R0 | RBM_R1 | RBM_R2 | RBM_R19 | RBM_R20);
static const regMaskTP LsraLimitSmallFPSet = (RBM_V0 | RBM_V1 | RBM_V2 | RBM_V8 | RBM_V9);
static constexpr regMaskTP LsraLimitSmallIntSet = (RBM_R0 | RBM_R1 | RBM_R2 | RBM_R19 | RBM_R20);
static constexpr regMaskTP LsraLimitSmallFPSet = (RBM_V0 | RBM_V1 | RBM_V2 | RBM_V8 | RBM_V9);
#elif defined(TARGET_X86)
static const regMaskTP LsraLimitSmallIntSet = (RBM_EAX | RBM_ECX | RBM_EDI);
static const regMaskTP LsraLimitSmallFPSet = (RBM_XMM0 | RBM_XMM1 | RBM_XMM2 | RBM_XMM6 | RBM_XMM7);
Expand Down
10 changes: 5 additions & 5 deletions src/coreclr/jit/lsraarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ regMaskTP LinearScan::filterConsecutiveCandidates(regMaskTP candidates,
unsigned int registersNeeded,
regMaskTP* allConsecutiveCandidates)
{
if (BitOperations::PopCount(candidates) < registersNeeded)
if (PopCount(candidates) < registersNeeded)
{
// There is no way the register demanded can be satisfied for this RefPosition
// based on the candidates from which it can allocate a register.
Expand All @@ -205,7 +205,7 @@ regMaskTP LinearScan::filterConsecutiveCandidates(regMaskTP candidates,
do
{
// From LSB, find the first available register (bit `1`)
regAvailableStartIndex = BitOperations::BitScanForward(static_cast<DWORD64>(currAvailableRegs));
regAvailableStartIndex = BitScanForward(currAvailableRegs);
regMaskTP startMask = (1ULL << regAvailableStartIndex) - 1;

// Mask all the bits that are processed from LSB thru regAvailableStart until the last `1`.
Expand All @@ -223,7 +223,7 @@ regMaskTP LinearScan::filterConsecutiveCandidates(regMaskTP candidates,
}
else
{
regAvailableEndIndex = BitOperations::BitScanForward(static_cast<DWORD64>(maskProcessed));
regAvailableEndIndex = BitScanForward(maskProcessed);
}
regMaskTP endMask = (1ULL << regAvailableEndIndex) - 1;

Expand Down Expand Up @@ -335,7 +335,7 @@ regMaskTP LinearScan::filterConsecutiveCandidatesForSpill(regMaskTP consecutiveC
do
{
// From LSB, find the first available register (bit `1`)
regAvailableStartIndex = BitOperations::BitScanForward(static_cast<DWORD64>(unprocessedRegs));
regAvailableStartIndex = BitScanForward(unprocessedRegs);

// For the current range, find how many registers are free vs. busy
regMaskTP maskForCurRange = RBM_NONE;
Expand Down Expand Up @@ -370,7 +370,7 @@ regMaskTP LinearScan::filterConsecutiveCandidatesForSpill(regMaskTP consecutiveC
// In the given range, there are some free registers available. Calculate how many registers
// will need spilling if this range is picked.

int curSpillRegs = registersNeeded - BitOperations::PopCount(maskForCurRange);
int curSpillRegs = registersNeeded - PopCount(maskForCurRange);
if (curSpillRegs < maxSpillRegs)
{
consecutiveResultForBusy = 1ULL << regAvailableStartIndex;
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/lsrabuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2898,7 +2898,7 @@ void LinearScan::stressSetRandomParameterPreferences()

// Select a random register from all possible parameter registers
// (of the right type). Preference this parameter to that register.
unsigned numBits = BitOperations::PopCount(*regs);
unsigned numBits = PopCount(*regs);
if (numBits == 0)
{
continue;
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/regset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ regMaskSmall genRegMaskFromCalleeSavedMask(unsigned short calleeSaveMask)
regMaskSmall res = 0;
for (int i = 0; i < CNT_CALL_GC_REGS; i++)
{
if ((calleeSaveMask & ((regMaskTP)1 << i)) != 0)
if ((calleeSaveMask & (1 << i)) != 0)
{
res |= raRbmCalleeSaveOrder[i];
}
Expand Down
Loading

0 comments on commit fdc9c9d

Please sign in to comment.