Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JIT ARM64-SVE: Add saturating decrement/increment by element count #102315

Merged
merged 32 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
51c328b
JIT ARM64-SVE: Add saturating decrement/increment by element count
a74nh May 3, 2024
286df97
Merge main
a74nh May 17, 2024
a9f6aaa
Add fallback for scalar variants
a74nh May 17, 2024
a40951f
Add fallback for vector variants
a74nh May 17, 2024
a43e870
HW_Flag_HasScalarVariant comment
a74nh May 20, 2024
e700111
Add IsCnsIntOrI asserts
a74nh May 20, 2024
a6cb5e3
Simpler SubtractSaturateScalar()
a74nh May 20, 2024
1e6082f
Combine codegen cases
a74nh May 20, 2024
84a5d22
Inline DecodePredicateCount()
a74nh May 20, 2024
4a9dc8a
Remove C# fallbacks
a74nh May 21, 2024
5025a8c
Add HWIntrinsicImmOpHelper variant
a74nh May 23, 2024
b79c21d
Add special import
a74nh May 28, 2024
ef74d46
Add out of range testing
a74nh May 28, 2024
999b5e4
Expand testing
a74nh May 28, 2024
e82dc4e
Add testing for different patterns
a74nh May 28, 2024
d021f18
Merge main
a74nh May 28, 2024
9780f7d
formatting
a74nh May 28, 2024
646f03c
Fix merge failure
a74nh May 28, 2024
9c4bec7
Remove entries of SaturatingIncrementByActiveElementCount()
kunalspathak May 28, 2024
cd779d5
HW_Flag_HasScalarInputVariant
a74nh May 29, 2024
c9569fc
review cleanups
a74nh May 29, 2024
38a5614
Fix comments from X86 to Arm
a74nh May 29, 2024
b66bab7
Add missing tests
a74nh May 29, 2024
4414b33
Add missing min/max for ConstantExpected
a74nh May 29, 2024
b034ea7
formatting
a74nh May 29, 2024
a536c01
Add isValidScalarIntrinsic check
a74nh May 29, 2024
1c11ce2
Add RMW asserts
a74nh May 30, 2024
6a3dfea
Move internal register use after dest def
a74nh May 31, 2024
5cbd267
Merge main
a74nh May 31, 2024
e7047ae
Use Saturating helpers in testing
a74nh May 31, 2024
975eb1e
Ensure Saturating helpers use the correct size
a74nh May 31, 2024
f6a232e
Set internal register as delay for Saturating
a74nh May 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/coreclr/jit/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,13 @@ class CodeGen final : public CodeGenInterface
class HWIntrinsicImmOpHelper final
{
public:
HWIntrinsicImmOpHelper(CodeGen* codeGen, GenTree* immOp, GenTreeHWIntrinsic* intrin, int immNum = 1);
HWIntrinsicImmOpHelper(CodeGen* codeGen, GenTree* immOp, GenTreeHWIntrinsic* intrin);

HWIntrinsicImmOpHelper(CodeGen* codeGen,
regNumber nonConstImmReg,
int immLowerBound,
int immUpperBound,
GenTreeHWIntrinsic* intrin);

void EmitBegin();
void EmitCaseEnd();
Expand Down Expand Up @@ -1040,6 +1046,7 @@ class CodeGen final : public CodeGenInterface
regNumber nonConstImmReg;
regNumber branchTargetReg;
};

#endif // TARGET_ARM64

#endif // FEATURE_HW_INTRINSICS
Expand Down
21 changes: 20 additions & 1 deletion src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,26 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
unsigned int sizeBytes;

simdBaseJitType = getBaseJitTypeAndSizeOfSIMDType(clsHnd, &sizeBytes);
assert((category == HW_Category_Special) || (category == HW_Category_Helper) || (sizeBytes != 0));

#if defined(TARGET_ARM64)
if (simdBaseJitType == CORINFO_TYPE_UNDEF && HWIntrinsicInfo::HasScalarVariant(intrinsic))
{
// Did not find a valid vector type. The intrinsic has alternate scalar version. Switch to that.

assert(sizeBytes == 0);
intrinsic = HWIntrinsicInfo::GetScalarVariant(intrinsic);
category = HWIntrinsicInfo::lookupCategory(intrinsic);
isa = HWIntrinsicInfo::lookupIsa(intrinsic);

simdBaseJitType = sig->retType;
assert(simdBaseJitType != CORINFO_TYPE_VOID);
assert(simdBaseJitType != CORINFO_TYPE_UNDEF);
a74nh marked this conversation as resolved.
Show resolved Hide resolved
}
else
#endif
{
assert((category == HW_Category_Special) || (category == HW_Category_Helper) || (sizeBytes != 0));
}
}
}

Expand Down
39 changes: 38 additions & 1 deletion src/coreclr/jit/hwintrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,14 @@ enum HWIntrinsicFlag : unsigned int
// The intrinsic has an enum operand. Using this implies HW_Flag_HasImmediateOperand.
HW_Flag_HasEnumOperand = 0x1000000,

// The intrinsic comes in both vector and scalar variants. During the import stage if the basetype is scalar,
// then the intrinsic should be switched to a scalar only version.
HW_Flag_HasScalarVariant = 0x2000000,
a74nh marked this conversation as resolved.
Show resolved Hide resolved
a74nh marked this conversation as resolved.
Show resolved Hide resolved

#endif // TARGET_XARCH

// The intrinsic is a FusedMultiplyAdd intrinsic
HW_Flag_FmaIntrinsic = 0x20000000,
HW_Flag_FmaIntrinsic = 0x40000000,

#if defined(TARGET_ARM64)
// The intrinsic uses a mask in arg1 to select elements present in the result, and must use a low vector register.
Expand Down Expand Up @@ -926,6 +930,39 @@ struct HWIntrinsicInfo
return (flags & HW_Flag_HasEnumOperand) != 0;
}

static bool HasScalarVariant(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_HasScalarVariant) != 0;
}

static NamedIntrinsic GetScalarVariant(NamedIntrinsic id)
{
switch (id)
a74nh marked this conversation as resolved.
Show resolved Hide resolved
{
case NI_Sve_SaturatingDecrementBy16BitElementCount:
return NI_Sve_SaturatingDecrementBy16BitElementCountScalar;

case NI_Sve_SaturatingDecrementBy32BitElementCount:
return NI_Sve_SaturatingDecrementBy32BitElementCountScalar;

case NI_Sve_SaturatingDecrementBy64BitElementCount:
return NI_Sve_SaturatingDecrementBy64BitElementCountScalar;

case NI_Sve_SaturatingIncrementBy16BitElementCount:
return NI_Sve_SaturatingIncrementBy16BitElementCountScalar;

case NI_Sve_SaturatingIncrementBy32BitElementCount:
return NI_Sve_SaturatingIncrementBy32BitElementCountScalar;

case NI_Sve_SaturatingIncrementBy64BitElementCount:
return NI_Sve_SaturatingIncrementBy64BitElementCountScalar;

default:
unreached();
}
}

#endif // TARGET_ARM64

static bool HasSpecialSideEffect(NamedIntrinsic id)
Expand Down
98 changes: 97 additions & 1 deletion src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,25 @@ void Compiler::getHWIntrinsicImmOps(NamedIntrinsic intrinsic,
imm2Pos = 0;
break;

case NI_Sve_SaturatingDecrementBy16BitElementCount:
case NI_Sve_SaturatingDecrementBy32BitElementCount:
case NI_Sve_SaturatingDecrementBy64BitElementCount:
case NI_Sve_SaturatingDecrementBy8BitElementCount:
case NI_Sve_SaturatingIncrementBy16BitElementCount:
case NI_Sve_SaturatingIncrementBy32BitElementCount:
case NI_Sve_SaturatingIncrementBy64BitElementCount:
case NI_Sve_SaturatingIncrementBy8BitElementCount:
case NI_Sve_SaturatingDecrementBy16BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy32BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy64BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy16BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy32BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy64BitElementCountScalar:
assert(sig->numArgs == 3);
imm1Pos = 1;
imm2Pos = 0;
kunalspathak marked this conversation as resolved.
Show resolved Hide resolved
break;

default:
assert(sig->numArgs > 0);
imm1Pos = 0;
Expand Down Expand Up @@ -447,6 +466,33 @@ void HWIntrinsicInfo::lookupImmBounds(
immUpperBound = (int)SVE_PATTERN_ALL;
break;

case NI_Sve_SaturatingDecrementBy16BitElementCount:
case NI_Sve_SaturatingDecrementBy32BitElementCount:
case NI_Sve_SaturatingDecrementBy64BitElementCount:
case NI_Sve_SaturatingDecrementBy8BitElementCount:
case NI_Sve_SaturatingIncrementBy16BitElementCount:
case NI_Sve_SaturatingIncrementBy32BitElementCount:
case NI_Sve_SaturatingIncrementBy64BitElementCount:
case NI_Sve_SaturatingIncrementBy8BitElementCount:
case NI_Sve_SaturatingDecrementBy16BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy32BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy64BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy16BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy32BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy64BitElementCountScalar:
if (immNumber == 1)
{
immLowerBound = 1;
immUpperBound = 16;
}
else
{
assert(immNumber == 2);
immLowerBound = (int)SVE_PATTERN_POW2;
immUpperBound = (int)SVE_PATTERN_ALL;
}
break;

default:
unreached();
}
Expand Down Expand Up @@ -512,7 +558,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
return gtNewScalarHWIntrinsicNode(TYP_VOID, intrinsic);
}

assert(category != HW_Category_Scalar);
bool isScalar = (category == HW_Category_Scalar);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we be explicit about this?

assert((category == HW_Category_Scalar) || (id == SaturatingDecrementBy8BitElementCount) || (id == SaturatingIncrementBy8BitElementCount))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point we've already potentially switched to the scalar variants. So the full assert would be:

    assert((category != HW_Category_Scalar) || (id == SaturatingDecrementBy8BitElementCount) || (id == SaturatingIncrementBy8BitElementCount)
            || (id == NI_Sve_SaturatingDecrementBy16BitElementCountScalar) || (id == NI_Sve_SaturatingDecrementBy32BitElementCountScalar)
            || (id == NI_Sve_SaturatingDecrementBy64BitElementCountScalar) || (id == NI_Sve_SaturatingIncrementBy16BitElementCountScalar)
            || (id == NI_Sve_SaturatingIncrementBy32BitElementCountScalar) || (id == NI_Sve_SaturatingIncrementBy64BitElementCountScalar));

Which feels excessive. Happy to switch if you still want.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed this by adding a DEBUG only bool.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are not HW_Category_Scalar and are not marked with isValidScalarIntrinsic ? does it work as expected?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for reminding me offline that they are HW_Category_SIMD and so should be fine.

assert(!HWIntrinsicInfo::isScalarIsa(HWIntrinsicInfo::lookupIsa(intrinsic)));

assert(numArgs >= 0);
Expand Down Expand Up @@ -2436,6 +2482,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
CORINFO_ARG_LIST_HANDLE arg3 = info.compCompHnd->getArgNext(arg2);
var_types argType = TYP_UNKNOWN;
CORINFO_CLASS_HANDLE argClass = NO_CLASS_HANDLE;

argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg3, &argClass)));
op3 = impPopStack().val;
unsigned fieldCount = info.compCompHnd->getClassNumInstanceFields(argClass);
Expand Down Expand Up @@ -2507,6 +2554,55 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_Sve_SaturatingDecrementBy16BitElementCount:
case NI_Sve_SaturatingDecrementBy32BitElementCount:
case NI_Sve_SaturatingDecrementBy64BitElementCount:
case NI_Sve_SaturatingDecrementBy8BitElementCount:
case NI_Sve_SaturatingIncrementBy16BitElementCount:
case NI_Sve_SaturatingIncrementBy32BitElementCount:
case NI_Sve_SaturatingIncrementBy64BitElementCount:
case NI_Sve_SaturatingIncrementBy8BitElementCount:
case NI_Sve_SaturatingDecrementBy16BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy32BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy64BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy16BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy32BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy64BitElementCountScalar:
{
assert(sig->numArgs == 3);

CORINFO_ARG_LIST_HANDLE arg1 = sig->args;
CORINFO_ARG_LIST_HANDLE arg2 = info.compCompHnd->getArgNext(arg1);
CORINFO_ARG_LIST_HANDLE arg3 = info.compCompHnd->getArgNext(arg2);
var_types argType = TYP_UNKNOWN;
CORINFO_CLASS_HANDLE argClass = NO_CLASS_HANDLE;
int immLowerBound = 0;
int immUpperBound = 0;

argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg3, &argClass)));
op3 = getArgForHWIntrinsic(argType, argClass);
argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg2, &argClass)));
op2 = getArgForHWIntrinsic(argType, argClass);
argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg1, &argClass)));
op1 = impPopStack().val;

CorInfoType op1BaseJitType = getBaseJitTypeOfSIMDType(argClass);

assert(HWIntrinsicInfo::isImmOp(intrinsic, op2));
HWIntrinsicInfo::lookupImmBounds(intrinsic, simdSize, simdBaseType, 1, &immLowerBound, &immUpperBound);
op2 = addRangeCheckIfNeeded(intrinsic, op2, (!op2->IsCnsIntOrI()), immLowerBound, immUpperBound);

assert(HWIntrinsicInfo::isImmOp(intrinsic, op3));
HWIntrinsicInfo::lookupImmBounds(intrinsic, simdSize, simdBaseType, 2, &immLowerBound, &immUpperBound);
op3 = addRangeCheckIfNeeded(intrinsic, op3, (!op3->IsCnsIntOrI()), immLowerBound, immUpperBound);

retNode = isScalar ? gtNewScalarHWIntrinsicNode(retType, op1, op2, op3, intrinsic)
: gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, intrinsic, simdBaseJitType, simdSize);

retNode->AsHWIntrinsic()->SetSimdBaseJitType(simdBaseJitType);
break;
}

default:
{
return nullptr;
Expand Down
107 changes: 100 additions & 7 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
// codeGen -- an instance of CodeGen class.
// immOp -- an immediate operand of the intrinsic.
// intrin -- a hardware intrinsic tree node.
// immNumber -- which immediate operand to use (most intrinsics only have one).
//
// Note: This class is designed to be used in the following way
// HWIntrinsicImmOpHelper helper(this, immOp, intrin);
Expand All @@ -36,10 +35,7 @@
// This allows to combine logic for cases when immOp->isContainedIntOrIImmed() is either true or false in a form
// of a for-loop.
//
CodeGen::HWIntrinsicImmOpHelper::HWIntrinsicImmOpHelper(CodeGen* codeGen,
GenTree* immOp,
GenTreeHWIntrinsic* intrin,
int immNumber /* = 1 */)
CodeGen::HWIntrinsicImmOpHelper::HWIntrinsicImmOpHelper(CodeGen* codeGen, GenTree* immOp, GenTreeHWIntrinsic* intrin)
: codeGen(codeGen)
, endLabel(nullptr)
, nonZeroLabel(nullptr)
Expand Down Expand Up @@ -79,12 +75,12 @@ CodeGen::HWIntrinsicImmOpHelper::HWIntrinsicImmOpHelper(CodeGen* code

const unsigned int indexedElementSimdSize = genTypeSize(indexedElementOpType);
HWIntrinsicInfo::lookupImmBounds(intrin->GetHWIntrinsicId(), indexedElementSimdSize,
intrin->GetSimdBaseType(), immNumber, &immLowerBound, &immUpperBound);
intrin->GetSimdBaseType(), 1, &immLowerBound, &immUpperBound);
}
else
{
HWIntrinsicInfo::lookupImmBounds(intrin->GetHWIntrinsicId(), intrin->GetSimdSize(),
intrin->GetSimdBaseType(), immNumber, &immLowerBound, &immUpperBound);
intrin->GetSimdBaseType(), 1, &immLowerBound, &immUpperBound);
}

nonConstImmReg = immOp->GetRegNum();
Expand All @@ -109,6 +105,37 @@ CodeGen::HWIntrinsicImmOpHelper::HWIntrinsicImmOpHelper(CodeGen* code
}
}

CodeGen::HWIntrinsicImmOpHelper::HWIntrinsicImmOpHelper(
a74nh marked this conversation as resolved.
Show resolved Hide resolved
CodeGen* codeGen, regNumber nonConstImmReg, int immLowerBound, int immUpperBound, GenTreeHWIntrinsic* intrin)
: codeGen(codeGen)
, endLabel(nullptr)
, nonZeroLabel(nullptr)
, immValue(immLowerBound)
, immLowerBound(immLowerBound)
, immUpperBound(immUpperBound)
, nonConstImmReg(nonConstImmReg)
, branchTargetReg(REG_NA)
{
assert(codeGen != nullptr);

if (TestImmOpZeroOrOne())
{
nonZeroLabel = codeGen->genCreateTempLabel();
}
else
{
// At the moment, this helper supports only intrinsics that correspond to one machine instruction.
// If we ever encounter an intrinsic that is either lowered into multiple instructions or
// the number of instructions that correspond to each case is unknown apriori - we can extend support to
// these by
// using the same approach as in hwintrinsicxarch.cpp - adding an additional indirection level in form of a
// branch table.
branchTargetReg = codeGen->internalRegisters.GetSingle(intrin);
}

endLabel = codeGen->genCreateTempLabel();
}

//------------------------------------------------------------------------
// EmitBegin: emits the beginning of a "switch" table, no-op if an immediate operand is constant.
//
Expand Down Expand Up @@ -1791,6 +1818,72 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
INS_SCALABLE_OPTS_UNPREDICATED);
break;

case NI_Sve_SaturatingDecrementBy16BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy32BitElementCountScalar:
case NI_Sve_SaturatingDecrementBy64BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy16BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy32BitElementCountScalar:
case NI_Sve_SaturatingIncrementBy64BitElementCountScalar:
// Use scalar sizes.
emitSize = emitActualTypeSize(node->gtType);
opt = INS_OPTS_NONE;
FALLTHROUGH;

case NI_Sve_SaturatingDecrementBy16BitElementCount:
case NI_Sve_SaturatingDecrementBy32BitElementCount:
case NI_Sve_SaturatingDecrementBy64BitElementCount:
case NI_Sve_SaturatingDecrementBy8BitElementCount:
case NI_Sve_SaturatingIncrementBy16BitElementCount:
case NI_Sve_SaturatingIncrementBy32BitElementCount:
case NI_Sve_SaturatingIncrementBy64BitElementCount:
case NI_Sve_SaturatingIncrementBy8BitElementCount:
{
assert(isRMW);
if (targetReg != op1Reg)
{
GetEmitter()->emitIns_Mov(INS_mov, emitTypeSize(node), targetReg, op1Reg, /* canSkip */ true);
}

if (intrin.op2->IsCnsIntOrI() && intrin.op3->IsCnsIntOrI())
{
// Both immediates are constant, emit the intruction.

assert(intrin.op2->isContainedIntOrIImmed() && intrin.op3->isContainedIntOrIImmed());
int scale = (int)intrin.op2->AsIntCon()->gtIconVal;
insSvePattern pattern = (insSvePattern)intrin.op3->AsIntCon()->gtIconVal;
GetEmitter()->emitIns_R_PATTERN_I(ins, emitSize, targetReg, pattern, scale, opt);
}
else
{
// Use the helper to generate a table.

assert(!intrin.op2->isContainedIntOrIImmed() && !intrin.op3->isContainedIntOrIImmed());

emitAttr scalarSize = emitActualTypeSize(node->GetSimdBaseType());

// Combine the second immediate (pattern, op3) into the first (scale, op2).
GetEmitter()->emitIns_R_R_I(INS_sub, scalarSize, op2Reg, op2Reg, 1);
a74nh marked this conversation as resolved.
Show resolved Hide resolved
GetEmitter()->emitIns_R_R_I(INS_lsl, scalarSize, op3Reg, op3Reg, 4);
GetEmitter()->emitIns_R_R_R(INS_orr, scalarSize, op2Reg, op2Reg, op3Reg);

// Generate a table using the combined immediate.
HWIntrinsicImmOpHelper helper(this, op2Reg, 0, 511, node);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
const int value = helper.ImmValue();
const int scale = (value & 0xF) + 1;
const insSvePattern pattern = (insSvePattern)(value >> 4);
GetEmitter()->emitIns_R_PATTERN_I(ins, emitSize, targetReg, pattern, scale, opt);
}

// Restore the immediates.
GetEmitter()->emitIns_R_R_I(INS_and, scalarSize, op2Reg, op2Reg, 0xF);
GetEmitter()->emitIns_R_R_I(INS_lsr, scalarSize, op3Reg, op3Reg, 4);
GetEmitter()->emitIns_R_R_I(INS_add, scalarSize, op2Reg, op2Reg, 1);
}
break;
}

default:
unreached();
}
Expand Down
Loading
Loading