-
Notifications
You must be signed in to change notification settings - Fork 4.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Arm64/Sve: Implement SVE Math *Multiply* APIs #102007
Changes from 6 commits
97373ca
4e14098
3fb9dea
600391a
67e4d4d
54899b2
e4a53ae
bfad7b7
100f289
8ac1840
9eb195e
c182d0d
62ea159
28a49cb
722dd55
229f78f
a21439f
6a01ca4
318cbf3
e3fc830
1ca5539
eb41e1d
7874f25
f756afb
2904934
53d29a0
98ac0ce
0f89e10
8e928ec
c713d31
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -417,10 +417,16 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) | |
regNumber maskReg = op1Reg; | ||
regNumber embMaskOp1Reg = REG_NA; | ||
regNumber embMaskOp2Reg = REG_NA; | ||
regNumber embMaskOp3Reg = REG_NA; | ||
regNumber falseReg = op3Reg; | ||
|
||
switch (intrinEmbMask.numOperands) | ||
{ | ||
case 3: | ||
assert(intrinEmbMask.op3 != nullptr); | ||
embMaskOp3Reg = intrinEmbMask.op3->GetRegNum(); | ||
FALLTHROUGH; | ||
|
||
case 2: | ||
assert(intrinEmbMask.op2 != nullptr); | ||
embMaskOp2Reg = intrinEmbMask.op2->GetRegNum(); | ||
|
@@ -438,6 +444,7 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) | |
switch (intrinEmbMask.numOperands) | ||
{ | ||
case 1: | ||
{ | ||
assert(!instrIsRMW); | ||
|
||
if (targetReg != falseReg) | ||
|
@@ -488,9 +495,10 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) | |
|
||
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp1Reg, opt); | ||
break; | ||
} | ||
|
||
case 2: | ||
|
||
{ | ||
assert(instrIsRMW); | ||
|
||
if (intrin.op3->IsVectorZero()) | ||
|
@@ -560,7 +568,50 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) | |
} | ||
|
||
break; | ||
} | ||
case 3: | ||
{ | ||
assert(instrIsRMW); | ||
assert(targetReg != falseReg); | ||
assert(targetReg != embMaskOp2Reg); | ||
assert(targetReg != embMaskOp3Reg); | ||
assert(!HWIntrinsicInfo::IsOptionalEmbeddedMaskedOperation(intrinEmbMask.id)); | ||
|
||
if (intrin.op3->IsVectorZero()) | ||
{ | ||
// If `falseReg` is zero, then move the first operand of `intrinEmbMask` in the | ||
// destination using /Z. | ||
|
||
assert(targetReg != embMaskOp2Reg); | ||
GetEmitter()->emitIns_R_R_R(INS_sve_movprfx, emitSize, targetReg, maskReg, embMaskOp1Reg, opt); | ||
|
||
// Finally, perform the actual "predicated" operation so that `targetReg` is the first operand | ||
// `embMaskOp2Reg` is the second operand and `embMaskOp3Reg` is the third operand. | ||
GetEmitter()->emitIns_R_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg, | ||
embMaskOp3Reg, opt); | ||
} | ||
else | ||
{ | ||
// If the instruction just has "predicated" version, then move the "embMaskOp1Reg" | ||
// into targetReg. Next, do the predicated operation on the targetReg and last, | ||
// use "sel" to select the active lanes based on mask, and set inactive lanes | ||
// to falseReg. | ||
|
||
assert(HWIntrinsicInfo::IsEmbeddedMaskedOperation(intrinEmbMask.id)); | ||
|
||
if (targetReg != embMaskOp1Reg) | ||
{ | ||
GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, embMaskOp1Reg); | ||
} | ||
|
||
GetEmitter()->emitIns_R_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg, | ||
embMaskOp3Reg, opt); | ||
|
||
GetEmitter()->emitIns_R_R_R_R(INS_sve_sel, emitSize, targetReg, maskReg, targetReg, falseReg, | ||
opt, INS_SCALABLE_OPTS_UNPREDICATED); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there an assumption being made about the instruction being RMW here?
Given some fmla Zda, Pg/M, Zn, Zm Given some movprfx Zda, Pg/M, merge
fmla Zda, Pg/M, Zn, Zm Given some movprfx Zda, Pg/Z, Zda
fmla Zda, Pg/M, Zn, Zm There are then similar versions possible using We should actually never need mov dest, Zda
movprfx dest, Pg/M, merge
fmla dest, Pg/M, Zn, Zm This ends up being different from the other fallbacks that do use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The main goal of using In this case, we at worst need a 3 instruction sequence due to the required predication on the instruction. Thus, it becomes better to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
For the similar reasoning mentioned in #100743 (comment) (where we should only mov dest, Zda
fmla dest, Pg/M, Zn, Zm
sel dest, Pg/M, dest, merge
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, I misinterpreted the value of sel dest, Pg/M, Zda, merge
fmla dest, Pg/M, Zn, Zm |
||
} | ||
break; | ||
} | ||
default: | ||
unreached(); | ||
} | ||
|
@@ -627,6 +678,12 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) | |
INS_SCALABLE_OPTS_UNPREDICATED); | ||
} | ||
break; | ||
case 4: | ||
kunalspathak marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert(!isRMW); | ||
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, op3Reg, opt, | ||
INS_SCALABLE_OPTS_UNPREDICATED); | ||
break; | ||
|
||
default: | ||
unreached(); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,6 +46,12 @@ HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask32Bit, | |
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask64Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask) | ||
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask8Bit, -1, 2, false, {INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask) | ||
HARDWARE_INTRINSIC(Sve, Divide, -1, 2, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sdiv, INS_sve_udiv, INS_sve_sdiv, INS_sve_udiv, INS_sve_fdiv, INS_sve_fdiv}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation) | ||
HARDWARE_INTRINSIC(Sve, FusedMultiplyAdd, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmla, INS_sve_fmla}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are always using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently, I am just preferencing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, that sounds reasonable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would definitely expect us to have some logic around picking FMLA vs FMAD. The x64 logic is even more complex because it has to handle the RMW consideration (should the x64 then repeats this logic again in LSRA to actually set the I expect that Arm64 just needs to mirror the LSRA and codegen logic (ignoring any bits relevant to containment) and picking |
||
HARDWARE_INTRINSIC(Sve, FusedMultiplyAddBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmla, INS_sve_fmla}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics) | ||
HARDWARE_INTRINSIC(Sve, FusedMultiplyAddNegated, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fnmla, INS_sve_fnmla}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation) | ||
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtract, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation) | ||
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics) | ||
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractNegated, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fnmls, INS_sve_fnmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation) | ||
HARDWARE_INTRINSIC(Sve, LoadVector, -1, 2, true, {INS_sve_ld1b, INS_sve_ld1b, INS_sve_ld1h, INS_sve_ld1h, INS_sve_ld1w, INS_sve_ld1w, INS_sve_ld1d, INS_sve_ld1d, INS_sve_ld1w, INS_sve_ld1d}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation) | ||
HARDWARE_INTRINSIC(Sve, LoadVectorByteZeroExtendToInt16, -1, 2, false, {INS_invalid, INS_invalid, INS_sve_ld1b, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation) | ||
HARDWARE_INTRINSIC(Sve, LoadVectorByteZeroExtendToInt32, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ld1b, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation) | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1772,7 +1772,7 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// then record delay-free for operands as well as the "merge" value | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
GenTreeHWIntrinsic* intrinEmbOp2 = intrin.op2->AsHWIntrinsic(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t numArgs = intrinEmbOp2->GetOperandCount(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert((numArgs == 1) || (numArgs == 2)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert((numArgs == 1) || (numArgs == 2) || (numArgs == 3)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tgtPrefUse = BuildUse(intrinEmbOp2->Op(1)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
srcCount += 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1792,7 +1792,8 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert(intrin.op1 != nullptr); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
bool forceOp2DelayFree = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
bool forceOp2DelayFree = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
regMaskTP candidates = RBM_NONE; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ((intrin.id == NI_Vector64_GetElement) || (intrin.id == NI_Vector128_GetElement)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!intrin.op2->IsCnsIntOrI() && (!intrin.op1->isContained() || intrin.op1->OperIsLocal())) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1815,6 +1816,22 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ((intrin.id == NI_Sve_FusedMultiplyAddBySelectedScalar) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do these require special code here? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Because, as per FMLA (indexed), the register candidates for the indexed-element operand are restricted according to the element type (`RBM_SVE_INDEXED_S_ELEMENT_ALLOWED_REGS` for float, `RBM_SVE_INDEXED_D_ELEMENT_ALLOWED_REGS` for double). We have similar code for AdvSimd too and most likely, if I see more patterns in future, I will combine this code with it. runtime/src/coreclr/jit/lsraarm64.cpp Lines 1586 to 1613 in 3fce4e7
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
(intrin.id == NI_Sve_FusedMultiplySubtractBySelectedScalar)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// If this is common pattern, then we will add a flag in the table, but for now, just check for specific | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// intrinsics | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (intrin.baseType == TYP_DOUBLE) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
candidates = RBM_SVE_INDEXED_D_ELEMENT_ALLOWED_REGS; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert(intrin.baseType == TYP_FLOAT); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
candidates = RBM_SVE_INDEXED_S_ELEMENT_ALLOWED_REGS; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ((intrin.id == NI_Sve_ConditionalSelect) && (intrin.op2->IsEmbMaskOp()) && | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
(intrin.op2->isRMWHWIntrinsic(compiler))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1845,7 +1862,8 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (intrin.op3 != nullptr) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
srcCount += isRMW ? BuildDelayFreeUses(intrin.op3, intrin.op1) : BuildOperandUses(intrin.op3); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
srcCount += isRMW ? BuildDelayFreeUses(intrin.op3, intrin.op1, candidates) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
: BuildOperandUses(intrin.op3, candidates); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (intrin.op4 != nullptr) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be asserting that
intrin.op3
is contained? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added `