Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARM64: Fix lsra for AdvSimd_LoadAndInsertScalar #107786

Merged
merged 9 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/coreclr/jit/hwintrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ enum HWIntrinsicFlag : unsigned int
// The intrinsic uses a mask in arg1 to select elements present in the result
HW_Flag_ExplicitMaskedOperation = 0x20000,

// The intrinsic uses a mask in arg1 (either explicitly, embdedd or optionally embedded) to select elements present
// The intrinsic uses a mask in arg1 (either explicitly, embedded or optionally embedded) to select elements present
// in the result, and must use a low register.
HW_Flag_LowMaskedOperation = 0x40000,

Expand Down
9 changes: 7 additions & 2 deletions src/coreclr/jit/lsraarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,7 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
{
assert(isRMW);
assert(intrin.op1->OperIs(GT_FIELD_LIST));

GenTreeFieldList* op1 = intrin.op1->AsFieldList();
assert(compiler->info.compNeedsConsecutiveRegisters);

Expand Down Expand Up @@ -1724,7 +1725,7 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
}
}
}
else if (intrinsicTree->OperIsMemoryLoadOrStore())
else if ((intrinsicTree->OperIsMemoryLoadOrStore()) && (intrin.id != NI_AdvSimd_LoadAndInsertScalar))
{
srcCount += BuildAddrUses(intrin.op1);
}
Expand Down Expand Up @@ -2151,7 +2152,11 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
{
SingleTypeRegSet candidates = lowVectorOperandNum == 3 ? lowVectorCandidates : RBM_NONE;

if (isRMW)
if (intrin.id == NI_AdvSimd_LoadAndInsertScalar)
{
srcCount += BuildAddrUses(intrin.op3);
}
else if (isRMW)
{
srcCount += BuildDelayFreeUses(intrin.op3, (tgtPrefOp2 ? intrin.op2 : intrin.op1), candidates);
}
Expand Down
31 changes: 26 additions & 5 deletions src/coreclr/jit/lsrabuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3846,9 +3846,25 @@ int LinearScan::BuildDelayFreeUses(GenTree* node,
return 0;
}
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kunalspathak : As suggested offline, added some additional checks in the delay slot logic. If the register type of the register does not match that of the delay slot register then do not add the delay free use. This fixes where we call BuildDelayFreeUses() for op2 which is a integer value.

Instead, we could do these checks in BuildHWIntrinsic() to not call BuildDelayFreeUses(), but then BuildDelayFreeUses() would still need the same checks turned into asserts. So it seemed simpler to do inside.

This will work nicely with BuildHWIntrinsic() rewrite PR too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally this should be assert and the callers should not be calling delay free uses for different register type. Did you turned it into an assert and see how many places are hit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Difficult to test with an assert because 1) all the hwintrinsic tests will fail at the first assert 2) many of these issues will be due during ConstantExpected APIs, and for a lot of those we are only testing using hardcoded constants, meaning the assert will never be hit.

Scanning the SVE API, all the methods we have that are RMW and have a ConstantExpected are:

AddRotateComplex
DotProductBySelectedScalar
ExtractVector
FusedMultiplyAddBySelectedScalar
FusedMultiplySubtractBySelectedScalar
MultiplyAddRotateComplex
MultiplyAddRotateComplexBySelectedScalar
MultiplyBySelectedScalar
SaturatingDecrementBy16BitElementCount
SaturatingDecrementBy32BitElementCount
SaturatingDecrementBy64BitElementCount
SaturatingDecrementBy8BitElementCount
SaturatingDecrementByActiveElementCount
SaturatingIncrementBy16BitElementCount
SaturatingIncrementBy32BitElementCount
SaturatingIncrementBy64BitElementCount
SaturatingIncrementBy8BitElementCount
SaturatingIncrementByActiveElementCount
ShiftRightArithmeticForDivide
TrigonometricMultiplyAddCoefficient

Then there are AdvSimd ones. Then there are possibly others.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is much easier to fix properly in the lsra rewrite PR, as the check can just be added into the main for loop in BuildHWIntrinsic()

        else if (delayFreeOp != nullptr && TheExtraChecksFromThisPR.....)
        {
            srcCount += BuildDelayFreeUses(operand, delayFreeOp, candidates);
        }
        else
        {
            srcCount += BuildOperandUses(operand, candidates);
        }

// Don't mark as delay free if there is a mismatch in register types
bool addDelayFreeUses = false;
// Multi register nodes should not go via this route.
assert(!node->IsMultiRegNode());
// Multi register nodes should always use fp registers (this includes vectors).
assert(varTypeUsesFloatReg(node->TypeGet()) || !node->IsMultiRegNode());
if (rmwNode == nullptr || varTypeUsesSameRegType(rmwNode->TypeGet(), node->TypeGet()) ||
(rmwNode->IsMultiRegNode() && varTypeUsesFloatReg(node->TypeGet())))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did not understand the && varTypeUsesFloatReg(node->TypeGet()) here? did you see code where having delay free is ok to have for multi-reg node that are GPR (if there is such a thing)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the RMW node is a multi register, then it'll always be vector registers.

The normal node could be general register or vector register. We want to discount the case where node is a general register. So, varTypeUsesFloatReg() is confirming node uses a FP, vector or sve vector register

Eg:

(Vector128<sbyte> Value1, Vector128<sbyte> Value2) LoadAndInsertScalar((Vector128<sbyte>, Vector128<sbyte>) values, [ConstantExpected(Max = (byte)(15))] byte index, sbyte* address)

The RMW node is the multi register values.
This code makes sure index is not delay slotted.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering we should have an assert that if it is a multiRegNode, then node should be floatreg?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some asserts.

That then broke things because for delayFreeMultiple intrinsics we do:

BuildDelayFreeUses(use.GetNode(), intrinsicTree);

Which is passing the entire intrinsicTree down as the rmw node which feels wrong. Fixed that to pass op1 instead.

{
addDelayFreeUses = true;
}

if (use != nullptr)
{
AddDelayFreeUses(use, rmwNode);
if (addDelayFreeUses)
{
AddDelayFreeUses(use, rmwNode);
}
if (useRefPositionRef != nullptr)
{
*useRefPositionRef = use;
Expand All @@ -3864,15 +3880,20 @@ int LinearScan::BuildDelayFreeUses(GenTree* node,
if (addrMode->HasBase() && !addrMode->Base()->isContained())
{
use = BuildUse(addrMode->Base(), candidates);
AddDelayFreeUses(use, rmwNode);

if (addDelayFreeUses)
{
AddDelayFreeUses(use, rmwNode);
}
srcCount++;
}

if (addrMode->HasIndex() && !addrMode->Index()->isContained())
{
use = BuildUse(addrMode->Index(), candidates);
AddDelayFreeUses(use, rmwNode);

if (addDelayFreeUses)
{
AddDelayFreeUses(use, rmwNode);
}
srcCount++;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ namespace JIT.HardwareIntrinsics.Arm
// Validates passing an instance member of a class works
test.RunClassFldScenario();

// Validates passing an non const value works
test.RunClassFldScenario_NotConstant();

// Validates passing the field of a local struct works
test.RunStructLclFldScenario();

Expand Down Expand Up @@ -150,6 +153,7 @@ namespace JIT.HardwareIntrinsics.Arm
private static {Op1BaseType}[] _data1 = new {Op1BaseType}[Op1ElementCount];

private {Op1VectorType}<{Op1BaseType}> _fld1;
private byte _fld2;
private {Op1BaseType} _fld3;

private DataTable _dataTable;
Expand All @@ -161,6 +165,7 @@ namespace JIT.HardwareIntrinsics.Arm
for (var i = 0; i < Op1ElementCount; i++) { _data1[i] = {NextValueOp1}; }
Unsafe.CopyBlockUnaligned(ref Unsafe.As<{Op1VectorType}<{Op1BaseType}>, byte>(ref _fld1), ref Unsafe.As<{Op1BaseType}, byte>(ref _data1[0]), (uint)Unsafe.SizeOf<{Op1VectorType}<{Op1BaseType}>>());

_fld2 = {ElementIndex};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lot of the tests that takes immediate value is missing this coverage. We should fix it some day.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an issue: #108060

_fld3 = {NextValueOp3};

for (var i = 0; i < Op1ElementCount; i++) { _data1[i] = {NextValueOp1}; }
Expand Down Expand Up @@ -247,6 +252,20 @@ namespace JIT.HardwareIntrinsics.Arm
ValidateResult(_fld1, _fld3, _dataTable.outArrayPtr);
}

public void RunClassFldScenario_NotConstant()
{
TestLibrary.TestFramework.BeginScenario(nameof(RunClassFldScenario_NotConstant));

fixed ({Op1BaseType}* pFld3 = &_fld3)
{
var result = {Isa}.{Method}(_fld1, _fld2, pFld3);

Unsafe.Write(_dataTable.outArrayPtr, result);
}

ValidateResult(_fld1, _fld3, _dataTable.outArrayPtr);
}

public void RunStructLclFldScenario()
{
TestLibrary.TestFramework.BeginScenario(nameof(RunStructLclFldScenario));
Expand Down
Loading