Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JIT ARM64-SVE: Add saturating decrement/increment by element count #102315

Merged
merged 32 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
51c328b
JIT ARM64-SVE: Add saturating decrement/increment by element count
a74nh May 3, 2024
286df97
Merge main
a74nh May 17, 2024
a9f6aaa
Add fallback for scalar variants
a74nh May 17, 2024
a40951f
Add fallback for vector variants
a74nh May 17, 2024
a43e870
HW_Flag_HasScalarVariant comment
a74nh May 20, 2024
e700111
Add IsCnsIntOrI asserts
a74nh May 20, 2024
a6cb5e3
Simpler SubtractSaturateScalar()
a74nh May 20, 2024
1e6082f
Combine codegen cases
a74nh May 20, 2024
84a5d22
Inline DecodePredicateCount()
a74nh May 20, 2024
4a9dc8a
Remove C# fallbacks
a74nh May 21, 2024
5025a8c
Add HWIntrinsicImmOpHelper variant
a74nh May 23, 2024
b79c21d
Add special import
a74nh May 28, 2024
ef74d46
Add out of range testing
a74nh May 28, 2024
999b5e4
Expand testing
a74nh May 28, 2024
e82dc4e
Add testing for different patterns
a74nh May 28, 2024
d021f18
Merge main
a74nh May 28, 2024
9780f7d
formatting
a74nh May 28, 2024
646f03c
Fix merge failure
a74nh May 28, 2024
9c4bec7
Remove entries of SaturatingIncrementByActiveElementCount()
kunalspathak May 28, 2024
cd779d5
HW_Flag_HasScalarInputVariant
a74nh May 29, 2024
c9569fc
review cleanups
a74nh May 29, 2024
38a5614
Fix comments from X86 to Arm
a74nh May 29, 2024
b66bab7
Add missing tests
a74nh May 29, 2024
4414b33
Add missing min/max for ConstantExpected
a74nh May 29, 2024
b034ea7
formatting
a74nh May 29, 2024
a536c01
Add isValidScalarIntrinsic check
a74nh May 29, 2024
1c11ce2
Add RMW asserts
a74nh May 30, 2024
6a3dfea
Move internal register use after dest def
a74nh May 31, 2024
5cbd267
Merge main
a74nh May 31, 2024
e7047ae
Use Saturating helpers in testing
a74nh May 31, 2024
975eb1e
Ensure Saturating helpers use the correct size
a74nh May 31, 2024
f6a232e
Set internal register as delay for Saturating
a74nh May 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/coreclr/jit/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,10 @@ class CodeGen final : public CodeGenInterface
class HWIntrinsicImmOpHelper final
{
public:
HWIntrinsicImmOpHelper(CodeGen* codeGen, GenTree* immOp, GenTreeHWIntrinsic* intrin, int immNum = 1);
HWIntrinsicImmOpHelper(CodeGen* codeGen, GenTree* immOp, GenTreeHWIntrinsic* intrin);

HWIntrinsicImmOpHelper(
CodeGen* codeGen, regNumber immReg, int immLowerBound, int immUpperBound, GenTreeHWIntrinsic* intrin);

void EmitBegin();
void EmitCaseEnd();
Expand Down Expand Up @@ -1040,6 +1043,7 @@ class CodeGen final : public CodeGenInterface
regNumber nonConstImmReg;
regNumber branchTargetReg;
};

#endif // TARGET_ARM64

#endif // FEATURE_HW_INTRINSICS
Expand Down
22 changes: 21 additions & 1 deletion src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1242,7 +1242,27 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
unsigned int sizeBytes;

simdBaseJitType = getBaseJitTypeAndSizeOfSIMDType(clsHnd, &sizeBytes);
assert((category == HW_Category_Special) || (category == HW_Category_Helper) || (sizeBytes != 0));

#if defined(TARGET_ARM64)
if (simdBaseJitType == CORINFO_TYPE_UNDEF && HWIntrinsicInfo::HasScalarInputVariant(intrinsic))
{
// Did not find a valid vector type. The intrinsic has alternate scalar version. Switch to that.

assert(sizeBytes == 0);
intrinsic = HWIntrinsicInfo::GetScalarInputVariant(intrinsic);
category = HWIntrinsicInfo::lookupCategory(intrinsic);
isa = HWIntrinsicInfo::lookupIsa(intrinsic);

simdBaseJitType = sig->retType;
assert(simdBaseJitType != CORINFO_TYPE_VOID);
assert(simdBaseJitType != CORINFO_TYPE_UNDEF);
a74nh marked this conversation as resolved.
Show resolved Hide resolved
assert(simdBaseJitType != CORINFO_TYPE_VALUECLASS);
}
else
#endif
{
assert((category == HW_Category_Special) || (category == HW_Category_Helper) || (sizeBytes != 0));
}
}
}

Expand Down
41 changes: 40 additions & 1 deletion src/coreclr/jit/hwintrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,14 @@ enum HWIntrinsicFlag : unsigned int
// The intrinsic has an enum operand. Using this implies HW_Flag_HasImmediateOperand.
HW_Flag_HasEnumOperand = 0x1000000,

// The intrinsic comes in both vector and scalar variants. During the import stage if the basetype is scalar,
// then the intrinsic should be switched to a scalar only version.
HW_Flag_HasScalarInputVariant = 0x2000000,

#endif // TARGET_XARCH

// The intrinsic is a FusedMultiplyAdd intrinsic
HW_Flag_FmaIntrinsic = 0x20000000,
HW_Flag_FmaIntrinsic = 0x40000000,

#if defined(TARGET_ARM64)
// The intrinsic uses a mask in arg1 to select elements present in the result, and must use a low vector register.
Expand Down Expand Up @@ -929,6 +933,41 @@ struct HWIntrinsicInfo
return (flags & HW_Flag_HasEnumOperand) != 0;
}

static bool HasScalarInputVariant(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_HasScalarInputVariant) != 0;
}

static NamedIntrinsic GetScalarInputVariant(NamedIntrinsic id)
{
assert(HasScalarInputVariant(id));

switch (id)
a74nh marked this conversation as resolved.
Show resolved Hide resolved
{
case NI_Sve_SaturatingDecrementBy16BitElementCount:
return NI_Sve_SaturatingDecrementBy16BitElementCountScalar;

case NI_Sve_SaturatingDecrementBy32BitElementCount:
return NI_Sve_SaturatingDecrementBy32BitElementCountScalar;

case NI_Sve_SaturatingDecrementBy64BitElementCount:
return NI_Sve_SaturatingDecrementBy64BitElementCountScalar;

case NI_Sve_SaturatingIncrementBy16BitElementCount:
return NI_Sve_SaturatingIncrementBy16BitElementCountScalar;

case NI_Sve_SaturatingIncrementBy32BitElementCount:
return NI_Sve_SaturatingIncrementBy32BitElementCountScalar;

case NI_Sve_SaturatingIncrementBy64BitElementCount:
return NI_Sve_SaturatingIncrementBy64BitElementCountScalar;

default:
unreached();
}
}

#endif // TARGET_ARM64

static bool HasSpecialSideEffect(NamedIntrinsic id)
Expand Down
109 changes: 108 additions & 1 deletion src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,25 @@ void Compiler::getHWIntrinsicImmOps(NamedIntrinsic intrinsic,
imm2Pos = 0;
break;

case NI_Sve_SaturatingDecrementBy16BitElementCount:
case NI_Sve_SaturatingDecrementBy32BitElementCount:
case NI_Sve_SaturatingDecrementBy64BitElementCount:
case NI_Sve_SaturatingDecrementBy8BitElementCount:
case NI_Sve_SaturatingIncrementBy16BitElementCount:
case NI_Sve_SaturatingIncrementBy32BitElementCount:
case NI_Sve_SaturatingIncrementBy64BitElementCount:
case NI_Sve_SaturatingIncrementBy8BitElementCount:
case NI_Sve_SaturatingDecrementBy16BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy32BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy64BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy16BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy32BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy64BitElementCountScalar:
assert(sig->numArgs == 3);
imm1Pos = 1;
imm2Pos = 0;
kunalspathak marked this conversation as resolved.
Show resolved Hide resolved
break;

default:
assert(sig->numArgs > 0);
imm1Pos = 0;
Expand Down Expand Up @@ -447,6 +466,33 @@ void HWIntrinsicInfo::lookupImmBounds(
immUpperBound = (int)SVE_PATTERN_ALL;
break;

case NI_Sve_SaturatingDecrementBy16BitElementCount:
case NI_Sve_SaturatingDecrementBy32BitElementCount:
case NI_Sve_SaturatingDecrementBy64BitElementCount:
case NI_Sve_SaturatingDecrementBy8BitElementCount:
case NI_Sve_SaturatingIncrementBy16BitElementCount:
case NI_Sve_SaturatingIncrementBy32BitElementCount:
case NI_Sve_SaturatingIncrementBy64BitElementCount:
case NI_Sve_SaturatingIncrementBy8BitElementCount:
case NI_Sve_SaturatingDecrementBy16BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy32BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy64BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy16BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy32BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy64BitElementCountScalar:
if (immNumber == 1)
{
immLowerBound = 1;
immUpperBound = 16;
}
else
{
assert(immNumber == 2);
immLowerBound = (int)SVE_PATTERN_POW2;
immUpperBound = (int)SVE_PATTERN_ALL;
}
break;

default:
unreached();
}
Expand Down Expand Up @@ -513,7 +559,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
return gtNewScalarHWIntrinsicNode(TYP_VOID, intrinsic);
}

assert(category != HW_Category_Scalar);
bool isScalar = (category == HW_Category_Scalar);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we be explicit about this?

assert((category == HW_Category_Scalar) || (id == SaturatingDecrementBy8BitElementCount) || (id == SaturatingIncrementBy8BitElementCount))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point we've already potentially switched to the scalar variants. So the full assert would be:

    assert((category != HW_Category_Scalar) || (id == SaturatingDecrementBy8BitElementCount) || (id == SaturatingIncrementBy8BitElementCount)
            || (id == NI_Sve_SaturatingDecrementBy16BitElementCountScalar) || (id == NI_Sve_SaturatingDecrementBy32BitElementCountScalar)
            || (id == NI_Sve_SaturatingDecrementBy64BitElementCountScalar) || (id == NI_Sve_SaturatingIncrementBy16BitElementCountScalar)
            || (id == NI_Sve_SaturatingIncrementBy32BitElementCountScalar) || (id == NI_Sve_SaturatingIncrementBy64BitElementCountScalar));

Which feels excessive. Happy to switch if you still want.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed this by adding a DEBUG only bool.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are not HW_Category_Scalar and are not marked with isValidScalarIntrinsic ? does it work as expected?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for reminding me offline that they are HW_Category_SIMD and so should be fine.

assert(!HWIntrinsicInfo::isScalarIsa(HWIntrinsicInfo::lookupIsa(intrinsic)));

assert(numArgs >= 0);
Expand All @@ -527,6 +573,10 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
GenTree* op3 = nullptr;
GenTree* op4 = nullptr;

#ifdef DEBUG
bool isValidScalarIntrinsic = false;
#endif

switch (intrinsic)
{
case NI_Vector64_Abs:
Expand Down Expand Up @@ -2473,6 +2523,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
CORINFO_ARG_LIST_HANDLE arg3 = info.compCompHnd->getArgNext(arg2);
var_types argType = TYP_UNKNOWN;
CORINFO_CLASS_HANDLE argClass = NO_CLASS_HANDLE;

argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg3, &argClass)));
op3 = impPopStack().val;
unsigned fieldCount = info.compCompHnd->getClassNumInstanceFields(argClass);
Expand Down Expand Up @@ -2544,12 +2595,68 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_Sve_SaturatingDecrementBy8BitElementCount:
case NI_Sve_SaturatingIncrementBy8BitElementCount:
case NI_Sve_SaturatingDecrementBy16BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy32BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy64BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy16BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy32BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy64BitElementCountScalar:
#ifdef DEBUG
isValidScalarIntrinsic = true;
FALLTHROUGH;
#endif
case NI_Sve_SaturatingDecrementBy16BitElementCount:
case NI_Sve_SaturatingDecrementBy32BitElementCount:
case NI_Sve_SaturatingDecrementBy64BitElementCount:
case NI_Sve_SaturatingIncrementBy16BitElementCount:
case NI_Sve_SaturatingIncrementBy32BitElementCount:
case NI_Sve_SaturatingIncrementBy64BitElementCount:

{
assert(sig->numArgs == 3);

CORINFO_ARG_LIST_HANDLE arg1 = sig->args;
CORINFO_ARG_LIST_HANDLE arg2 = info.compCompHnd->getArgNext(arg1);
CORINFO_ARG_LIST_HANDLE arg3 = info.compCompHnd->getArgNext(arg2);
var_types argType = TYP_UNKNOWN;
CORINFO_CLASS_HANDLE argClass = NO_CLASS_HANDLE;
int immLowerBound = 0;
int immUpperBound = 0;

argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg3, &argClass)));
op3 = getArgForHWIntrinsic(argType, argClass);
argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg2, &argClass)));
op2 = getArgForHWIntrinsic(argType, argClass);
argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg1, &argClass)));
op1 = impPopStack().val;

CorInfoType op1BaseJitType = getBaseJitTypeOfSIMDType(argClass);

assert(HWIntrinsicInfo::isImmOp(intrinsic, op2));
HWIntrinsicInfo::lookupImmBounds(intrinsic, simdSize, simdBaseType, 1, &immLowerBound, &immUpperBound);
op2 = addRangeCheckIfNeeded(intrinsic, op2, (!op2->IsCnsIntOrI()), immLowerBound, immUpperBound);

assert(HWIntrinsicInfo::isImmOp(intrinsic, op3));
HWIntrinsicInfo::lookupImmBounds(intrinsic, simdSize, simdBaseType, 2, &immLowerBound, &immUpperBound);
op3 = addRangeCheckIfNeeded(intrinsic, op3, (!op3->IsCnsIntOrI()), immLowerBound, immUpperBound);

retNode = isScalar ? gtNewScalarHWIntrinsicNode(retType, op1, op2, op3, intrinsic)
: gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, intrinsic, simdBaseJitType, simdSize);

retNode->AsHWIntrinsic()->SetSimdBaseJitType(simdBaseJitType);
break;
}

default:
{
return nullptr;
}
}

assert(!isScalar || isValidScalarIntrinsic);

return retNode;
}

Expand Down
Loading
Loading