Skip to content

Commit

Permalink
JIT ARM64-SVE: Add Sve.Abs() and Sve.Add()
Browse files Browse the repository at this point in the history
Change-Id: Ie8cfe828595da9a87adbc0857c0c44c0ce12f5b2
  • Loading branch information
a74nh committed Mar 22, 2024
1 parent cfe3d2d commit 0d437e3
Show file tree
Hide file tree
Showing 12 changed files with 1,070 additions and 64 deletions.
1 change: 1 addition & 0 deletions src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -3467,6 +3467,7 @@ class Compiler
#if defined(TARGET_ARM64)
GenTree* gtNewSimdConvertVectorToMaskNode(var_types type, GenTree* node, CorInfoType simdBaseJitType, unsigned simdSize);
GenTree* gtNewSimdConvertMaskToVectorNode(GenTreeHWIntrinsic* node, var_types type);
GenTree* gtNewSimdEmbeddedMaskNode(CorInfoType simdBaseJitType, unsigned simdSize);
#endif

//------------------------------------------------------------------------
Expand Down
70 changes: 54 additions & 16 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1396,6 +1396,60 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
GenTree* op3 = nullptr;
GenTree* op4 = nullptr;

switch (numArgs)
{
case 4:
op4 = getArgForHWIntrinsic(sigReader.GetOp4Type(), sigReader.op4ClsHnd);
op4 = addRangeCheckIfNeeded(intrinsic, op4, mustExpand, immLowerBound, immUpperBound);
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

case 3:
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

case 2:
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

case 1:
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

default:
break;
}

#if defined(TARGET_ARM64)
// Embedded masks need inserting as op1.
if (HWIntrinsicInfo::IsEmbeddedMaskedOperation(intrinsic))
{
numArgs++;
assert(numArgs <= 4);
switch (numArgs)
{
case 4:
op4 = op3;
FALLTHROUGH;
case 3:
op3 = op2;
FALLTHROUGH;
case 2:
op2 = op1;
FALLTHROUGH;
default:
break;
}
op1 = gtNewSimdEmbeddedMaskNode(simdBaseJitType, simdSize);
}
#endif

switch (numArgs)
{
case 0:
Expand All @@ -1407,8 +1461,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 1:
{
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

if ((category == HW_Category_MemoryLoad) && op1->OperIs(GT_CAST))
{
// Although the API specifies a pointer, if what we have is a BYREF, that's what
Expand Down Expand Up @@ -1467,10 +1519,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 2:
{
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

retNode = isScalar
? gtNewScalarHWIntrinsicNode(nodeRetType, op1, op2, intrinsic)
: gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, intrinsic, simdBaseJitType, simdSize);
Expand Down Expand Up @@ -1524,10 +1572,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 3:
{
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

#ifdef TARGET_ARM64
if (intrinsic == NI_AdvSimd_LoadAndInsertScalar)
{
Expand Down Expand Up @@ -1569,12 +1613,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 4:
{
op4 = getArgForHWIntrinsic(sigReader.GetOp4Type(), sigReader.op4ClsHnd);
op4 = addRangeCheckIfNeeded(intrinsic, op4, mustExpand, immLowerBound, immUpperBound);
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

assert(!isScalar);
retNode =
gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, op3, op4, intrinsic, simdBaseJitType, simdSize);
Expand Down
11 changes: 10 additions & 1 deletion src/coreclr/jit/hwintrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ enum HWIntrinsicFlag : unsigned int
// The intrinsic uses a mask in arg1 to select elements present in the result, and must use a low register.
HW_Flag_LowMaskedOperation = 0x40000,

// The intrinsic uses a mask in arg1 to select elements present in the result, which is not present in the API call
HW_Flag_EmbeddedMaskedOperation = 0x80000,

#else
#error Unsupported platform
#endif
Expand Down Expand Up @@ -872,7 +875,7 @@ struct HWIntrinsicInfo
static bool IsMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return ((flags & HW_Flag_MaskedOperation) != 0) || IsLowMaskedOperation(id);
return ((flags & HW_Flag_MaskedOperation) != 0) || IsLowMaskedOperation(id) || IsEmbeddedMaskedOperation(id);
}

static bool IsLowMaskedOperation(NamedIntrinsic id)
Expand All @@ -881,6 +884,12 @@ struct HWIntrinsicInfo
return (flags & HW_Flag_LowMaskedOperation) != 0;
}

static bool IsEmbeddedMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_EmbeddedMaskedOperation) != 0;
}

#endif // TARGET_ARM64

static bool HasSpecialSideEffect(NamedIntrinsic id)
Expand Down
17 changes: 16 additions & 1 deletion src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2222,7 +2222,7 @@ GenTree* Compiler::gtNewSimdConvertVectorToMaskNode(var_types type,
assert(varTypeIsSIMD(node));

// ConvertVectorToMask uses cmpne which requires an embedded mask.
GenTree* embeddedMask = gtNewSimdHWIntrinsicNode(TYP_MASK, NI_Sve_CreateTrueMaskAll, simdBaseJitType, simdSize);
GenTree* embeddedMask = gtNewSimdEmbeddedMaskNode(simdBaseJitType, simdSize);
return gtNewSimdHWIntrinsicNode(TYP_MASK, embeddedMask, node, NI_Sve_ConvertVectorToMask, simdBaseJitType,
simdSize);
}
Expand All @@ -2246,4 +2246,19 @@ GenTree* Compiler::gtNewSimdConvertMaskToVectorNode(GenTreeHWIntrinsic* node, va
node->GetSimdSize());
}

//------------------------------------------------------------------------
// gtNewSimdEmbeddedMaskNode: Create an embedded mask
//
// Arguments:
// simdBaseJitType -- the base jit type of the nodes being masked
// simdSize -- the simd size of the nodes being masked
//
// Return Value:
// The mask
//
GenTree* Compiler::gtNewSimdEmbeddedMaskNode(CorInfoType simdBaseJitType, unsigned simdSize)
{
return gtNewSimdHWIntrinsicNode(TYP_MASK, NI_Sve_CreateTrueMaskAll, simdBaseJitType, simdSize);
}

#endif // FEATURE_HW_INTRINSICS
81 changes: 58 additions & 23 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,64 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
unreached();
}
}
else if (isRMW)
{
assert(!hasImmediateOperand);
assert(!HWIntrinsicInfo::SupportsContainment(intrin.id));

// Move the RMW register out of the way and do not pass it to the emit.

if (HWIntrinsicInfo::IsEmbeddedMaskedOperation(intrin.id))
{
// op1Reg contains a mask, op2Reg contains the RMW register.

if (targetReg != op2Reg)
{
assert(targetReg != op3Reg);
GetEmitter()->emitIns_Mov(INS_mov, emitTypeSize(node), targetReg, op2Reg, /* canSkip */ true);
}

switch (intrin.numOperands)
{
case 2:
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt);
break;

case 3:
assert(targetReg != op3Reg);
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op3Reg, opt);
break;

default:
unreached();
}
}
else
{
// op1Reg contains the RMW register.

if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);
assert(targetReg != op3Reg);
GetEmitter()->emitIns_Mov(INS_mov, emitTypeSize(node), targetReg, op1Reg, /* canSkip */ true);
}

switch (intrin.numOperands)
{
case 2:
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op2Reg, opt);
break;

case 3:
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op2Reg, op3Reg, opt);
break;

default:
unreached();
}
}
}
else
{
assert(!hasImmediateOperand);
Expand All @@ -416,35 +474,12 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
{
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt);
}
else if (isRMW)
{
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);

GetEmitter()->emitIns_Mov(INS_mov, emitTypeSize(node), targetReg, op1Reg,
/* canSkip */ true);
}
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op2Reg, opt);
}
else
{
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
}
break;

case 3:
assert(isRMW);
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);
assert(targetReg != op3Reg);

GetEmitter()->emitIns_Mov(INS_mov, emitTypeSize(node), targetReg, op1Reg, /* canSkip */ true);
}
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op2Reg, op3Reg, opt);
break;

default:
unreached();
}
Expand Down
4 changes: 4 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
// SVE Intrinsics

// Sve
HARDWARE_INTRINSIC(Sve, Abs, -1, -1, false, {INS_sve_abs, INS_invalid, INS_sve_abs, INS_invalid, INS_sve_abs, INS_invalid, INS_sve_abs, INS_invalid, INS_sve_fabs, INS_sve_fabs}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation)

HARDWARE_INTRINSIC(Sve, Add, -1, -1, false, {INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_fadd, INS_sve_fadd}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)

HARDWARE_INTRINSIC(Sve, CreateTrueMaskByte, -1, 1, false, {INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_EnumPattern, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskDouble, -1, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ptrue}, HW_Category_EnumPattern, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskInt16, -1, 1, false, {INS_invalid, INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_EnumPattern, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_ReturnsPerElementMask)
Expand Down
Loading

0 comments on commit 0d437e3

Please sign in to comment.