Skip to content

Commit

Permalink
Arm64/Sve: Predicated Abs, Predicated/UnPredicated Add, Conditional S…
Browse files Browse the repository at this point in the history
…elect (#100743)

* JIT ARM64-SVE: Add Sve.Abs() and Sve.Add()

Change-Id: Ie8cfe828595da9a87adbc0857c0c44c0ce12f5b2

* Fix sve scaling in enitIns_R_S/S_R

* Revert "Fix sve scaling in enitIns_R_S/S_R"

This reverts commit e9fa735.

* Fix sve scaling in enitIns_R_S/S_R

* Restore testing

* Use NaturalScale_helper for vector load/stores

* wip

* Add ConditionalSelect() APIs

* Handle ConditionalSelect in JIT

* Add test coverage

* Update the test cases

* jit format

* fix merge conflicts

* Make predicated/unpredicated work with ConditionalSelect

Still some handling around RMW is needed, but this basically works

* Misc. changes

* jit format

* jit format

* Handle all the conditions correctly

* jit format

* fix some spacing

* Removed the assert

* fix the largest vector size to 64 to fix #100366

* review feedback

* wip

* Add SVE feature detection for Windows

* fix the check for invalid alignment

* Revert "Add SVE feature detection for Windows"

This reverts commit ed7c781.

* Handle case where Abs() is wrapped in another conditionalSelect

* jit format

* fix the size comparison

* HW_Flag_MaskedPredicatedOnlyOperation

* Revert the change in emitarm64.cpp around INS_sve_ldr_mask/INS_sve_str_mask

* Fix the condition for lowering

* address review feedback for movprfx

* Move the special handling of Vector<>.Zero from lowerer to importer

* Rename IsEmbeddedMaskedOperation/IsOptionalEmbeddedMaskedOperation

* Add more test coverage for conditionalSelect

* Rename test method name

* Add more test coverage for conditionalSelect:Abs

* jit format

* Add logging on test methods

* Add the missing movprfx for abs

* Add few more scenarios where falseVal is zero

* Make sure LoadVector is marked as explicit needing mask

* revisit the codegen logic

* Remove commented code and add some other comments

* jit format

---------

Co-authored-by: Alan Hayward <alan.hayward@arm.com>
  • Loading branch information
kunalspathak and a74nh authored Apr 24, 2024
1 parent 1a93a9d commit dd24e9c
Show file tree
Hide file tree
Showing 24 changed files with 1,994 additions and 74 deletions.
2 changes: 0 additions & 2 deletions src/coreclr/jit/codegenlinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1648,15 +1648,13 @@ void CodeGen::genConsumeRegs(GenTree* tree)
// Update the life of the lcl var.
genUpdateLife(tree);
}
#ifdef TARGET_XARCH
#ifdef FEATURE_HW_INTRINSICS
else if (tree->OperIs(GT_HWINTRINSIC))
{
GenTreeHWIntrinsic* hwintrinsic = tree->AsHWIntrinsic();
genConsumeMultiOpOperands(hwintrinsic);
}
#endif // FEATURE_HW_INTRINSICS
#endif // TARGET_XARCH
else if (tree->OperIs(GT_BITCAST, GT_NEG, GT_CAST, GT_LSH, GT_RSH, GT_RSZ, GT_ROR, GT_BSWAP, GT_BSWAP16))
{
genConsumeRegs(tree->gtGetOp1());
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -3477,6 +3477,7 @@ class Compiler
#if defined(TARGET_ARM64)
GenTree* gtNewSimdConvertVectorToMaskNode(var_types type, GenTree* node, CorInfoType simdBaseJitType, unsigned simdSize);
GenTree* gtNewSimdConvertMaskToVectorNode(GenTreeHWIntrinsic* node, var_types type);
GenTree* gtNewSimdAllTrueMaskNode(CorInfoType simdBaseJitType, unsigned simdSize);
#endif

//------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/emitloongarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ enum EmitCallType

EC_FUNC_TOKEN, // Direct call to a helper/static/nonvirtual/global method
// EC_FUNC_TOKEN_INDIR, // Indirect call to a helper/static/nonvirtual/global method
// EC_FUNC_ADDR, // Direct call to an absolute address
// EC_FUNC_ADDR, // Direct call to an absolute address

EC_INDIR_R, // Indirect call via register

Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/emitriscv64.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ enum EmitCallType

EC_FUNC_TOKEN, // Direct call to a helper/static/nonvirtual/global method
// EC_FUNC_TOKEN_INDIR, // Indirect call to a helper/static/nonvirtual/global method
// EC_FUNC_ADDR, // Direct call to an absolute address
// EC_FUNC_ADDR, // Direct call to an absolute address

// EC_FUNC_VIRTUAL, // Call to a virtual method (using the vtable)
EC_INDIR_R, // Indirect call via register
Expand Down
14 changes: 8 additions & 6 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18012,7 +18012,7 @@ bool GenTree::canBeContained() const
}
else if (OperIsHWIntrinsic() && !isContainableHWIntrinsic())
{
return isEvexEmbeddedMaskingCompatibleHWIntrinsic();
return isEmbeddedMaskingCompatibleHWIntrinsic();
}

return true;
Expand Down Expand Up @@ -19909,24 +19909,26 @@ bool GenTree::isEvexCompatibleHWIntrinsic() const
}

//------------------------------------------------------------------------
// isEvexEmbeddedMaskingCompatibleHWIntrinsic: Checks if the intrinsic is compatible
// isEmbeddedMaskingCompatibleHWIntrinsic : Checks if the intrinsic is compatible
// with the EVEX embedded masking form for its intended lowering instruction.
//
// Return Value:
// true if the intrisic node lowering instruction has an EVEX embedded masking
//
bool GenTree::isEvexEmbeddedMaskingCompatibleHWIntrinsic() const
bool GenTree::isEmbeddedMaskingCompatibleHWIntrinsic() const
{
#if defined(TARGET_XARCH)
if (OperIsHWIntrinsic())
{
#if defined(TARGET_XARCH)
// TODO-AVX512F-CQ: Expand this to the full set of APIs and make it table driven
// using IsEmbMaskingCompatible. For now, however, limit it to some explicit ids
// for prototyping purposes.
return (AsHWIntrinsic()->GetHWIntrinsicId() == NI_AVX512F_Add);
#elif defined(TARGET_ARM64)
return HWIntrinsicInfo::IsEmbeddedMaskedOperation(AsHWIntrinsic()->GetHWIntrinsicId()) ||
HWIntrinsicInfo::IsOptionalEmbeddedMaskedOperation(AsHWIntrinsic()->GetHWIntrinsicId());
#endif
}
#endif // TARGET_XARCH

return false;
}

Expand Down
12 changes: 6 additions & 6 deletions src/coreclr/jit/gentree.h
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,9 @@ enum GenTreeFlags : unsigned int

GTF_MDARRLOWERBOUND_NONFAULTING = 0x20000000, // GT_MDARR_LOWER_BOUND -- An MD array lower bound operation that cannot fault. Same as GT_IND_NONFAULTING.

#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS)
#ifdef FEATURE_HW_INTRINSICS
GTF_HW_EM_OP = 0x10000000, // GT_HWINTRINSIC -- node is used as an operand to an embedded mask
#endif // TARGET_XARCH && FEATURE_HW_INTRINSICS
#endif // FEATURE_HW_INTRINSICS
};

inline constexpr GenTreeFlags operator ~(GenTreeFlags a)
Expand Down Expand Up @@ -1465,7 +1465,7 @@ struct GenTree
bool isContainableHWIntrinsic() const;
bool isRMWHWIntrinsic(Compiler* comp);
bool isEvexCompatibleHWIntrinsic() const;
bool isEvexEmbeddedMaskingCompatibleHWIntrinsic() const;
bool isEmbeddedMaskingCompatibleHWIntrinsic() const;
#else
bool isCommutativeHWIntrinsic() const
{
Expand All @@ -1487,7 +1487,7 @@ struct GenTree
return false;
}

bool isEvexEmbeddedMaskingCompatibleHWIntrinsic() const
bool isEmbeddedMaskingCompatibleHWIntrinsic() const
{
return false;
}
Expand Down Expand Up @@ -2226,7 +2226,7 @@ struct GenTree
gtFlags &= ~GTF_ICON_HDL_MASK;
}

#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS)
#ifdef FEATURE_HW_INTRINSICS

bool IsEmbMaskOp()
{
Expand All @@ -2240,7 +2240,7 @@ struct GenTree
gtFlags |= GTF_HW_EM_OP;
}

#endif // TARGET_XARCH && FEATURE_HW_INTRINSICS
#endif // FEATURE_HW_INTRINSICS

static bool HandleKindDataIsInvariant(GenTreeFlags flags);

Expand Down
60 changes: 43 additions & 17 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1396,6 +1396,36 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
GenTree* op3 = nullptr;
GenTree* op4 = nullptr;

switch (numArgs)
{
case 4:
op4 = getArgForHWIntrinsic(sigReader.GetOp4Type(), sigReader.op4ClsHnd);
op4 = addRangeCheckIfNeeded(intrinsic, op4, mustExpand, immLowerBound, immUpperBound);
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

case 3:
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

case 2:
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

case 1:
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

default:
break;
}

switch (numArgs)
{
case 0:
Expand All @@ -1407,8 +1437,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 1:
{
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

if ((category == HW_Category_MemoryLoad) && op1->OperIs(GT_CAST))
{
// Although the API specifies a pointer, if what we have is a BYREF, that's what
Expand Down Expand Up @@ -1467,10 +1495,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 2:
{
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

retNode = isScalar
? gtNewScalarHWIntrinsicNode(nodeRetType, op1, op2, intrinsic)
: gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, intrinsic, simdBaseJitType, simdSize);
Expand Down Expand Up @@ -1524,10 +1548,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 3:
{
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

#ifdef TARGET_ARM64
if (intrinsic == NI_AdvSimd_LoadAndInsertScalar)
{
Expand Down Expand Up @@ -1569,12 +1589,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 4:
{
op4 = getArgForHWIntrinsic(sigReader.GetOp4Type(), sigReader.op4ClsHnd);
op4 = addRangeCheckIfNeeded(intrinsic, op4, mustExpand, immLowerBound, immUpperBound);
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

assert(!isScalar);
retNode =
gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, op3, op4, intrinsic, simdBaseJitType, simdSize);
Expand All @@ -1591,10 +1605,22 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}

#if defined(TARGET_ARM64)
if (HWIntrinsicInfo::IsMaskedOperation(intrinsic))
if (HWIntrinsicInfo::IsExplicitMaskedOperation(intrinsic))
{
assert(numArgs > 0);
GenTree* op1 = retNode->AsHWIntrinsic()->Op(1);
if (intrinsic == NI_Sve_ConditionalSelect)
{
if (op1->IsVectorAllBitsSet())
{
return retNode->AsHWIntrinsic()->Op(2);
}
else if (op1->IsVectorZero())
{
return retNode->AsHWIntrinsic()->Op(3);
}
}

if (!varTypeIsMask(op1))
{
// Op1 input is a vector. HWInstrinsic requires a mask.
Expand Down
29 changes: 27 additions & 2 deletions src/coreclr/jit/hwintrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,18 @@ enum HWIntrinsicFlag : unsigned int
HW_Flag_ReturnsPerElementMask = 0x10000,

// The intrinsic uses a mask in arg1 to select elements present in the result
HW_Flag_MaskedOperation = 0x20000,
HW_Flag_ExplicitMaskedOperation = 0x20000,

// The intrinsic uses a mask in arg1 to select elements present in the result, and must use a low register.
HW_Flag_LowMaskedOperation = 0x40000,

// The intrinsic can optionally use a mask in arg1 to select elements present in the result, which is not present in
// the API call
HW_Flag_OptionalEmbeddedMaskedOperation = 0x80000,

// The intrinsic uses a mask in arg1 to select elements present in the result, which is not present in the API call
HW_Flag_EmbeddedMaskedOperation = 0x100000,

#else
#error Unsupported platform
#endif
Expand Down Expand Up @@ -872,7 +879,7 @@ struct HWIntrinsicInfo
static bool IsMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return ((flags & HW_Flag_MaskedOperation) != 0) || IsLowMaskedOperation(id);
return IsLowMaskedOperation(id) || IsOptionalEmbeddedMaskedOperation(id) || IsExplicitMaskedOperation(id);
}

static bool IsLowMaskedOperation(NamedIntrinsic id)
Expand All @@ -881,6 +888,24 @@ struct HWIntrinsicInfo
return (flags & HW_Flag_LowMaskedOperation) != 0;
}

static bool IsOptionalEmbeddedMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_OptionalEmbeddedMaskedOperation) != 0;
}

static bool IsEmbeddedMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_EmbeddedMaskedOperation) != 0;
}

static bool IsExplicitMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_ExplicitMaskedOperation) != 0;
}

#endif // TARGET_ARM64

static bool HasSpecialSideEffect(NamedIntrinsic id)
Expand Down
20 changes: 17 additions & 3 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2222,9 +2222,8 @@ GenTree* Compiler::gtNewSimdConvertVectorToMaskNode(var_types type,
assert(varTypeIsSIMD(node));

// ConvertVectorToMask uses cmpne which requires an embedded mask.
GenTree* embeddedMask = gtNewSimdHWIntrinsicNode(TYP_MASK, NI_Sve_CreateTrueMaskAll, simdBaseJitType, simdSize);
return gtNewSimdHWIntrinsicNode(TYP_MASK, embeddedMask, node, NI_Sve_ConvertVectorToMask, simdBaseJitType,
simdSize);
GenTree* trueMask = gtNewSimdAllTrueMaskNode(simdBaseJitType, simdSize);
return gtNewSimdHWIntrinsicNode(TYP_MASK, trueMask, node, NI_Sve_ConvertVectorToMask, simdBaseJitType, simdSize);
}

//------------------------------------------------------------------------
Expand All @@ -2246,4 +2245,19 @@ GenTree* Compiler::gtNewSimdConvertMaskToVectorNode(GenTreeHWIntrinsic* node, va
node->GetSimdSize());
}

//------------------------------------------------------------------------
// gtNewSimdEmbeddedMaskNode: Create an embedded mask
//
// Arguments:
// simdBaseJitType -- the base jit type of the nodes being masked
// simdSize -- the simd size of the nodes being masked
//
// Return Value:
// The mask
//
GenTree* Compiler::gtNewSimdAllTrueMaskNode(CorInfoType simdBaseJitType, unsigned simdSize)
{
return gtNewSimdHWIntrinsicNode(TYP_MASK, NI_Sve_CreateTrueMaskAll, simdBaseJitType, simdSize);
}

#endif // FEATURE_HW_INTRINSICS
Loading

0 comments on commit dd24e9c

Please sign in to comment.