[AArch64] Add lowering for @llvm.experimental.vector.compress #101015

Merged: 10 commits, merged on Aug 13, 2024
Changes from 6 commits
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (62 changes: 56 additions & 6 deletions)
@@ -2408,11 +2408,61 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS(SDNode *N, SDValue &Lo,
SDValue &Hi) {
// This is not "trivial", as there is a dependency between the two subvectors.
// Depending on the number of 1s in the mask, the elements from the Hi vector
// need to be moved to the Lo vector. So we just perform this as one "big"
// operation and then extract the Lo and Hi vectors from that. This gets rid
// of VECTOR_COMPRESS and all other operands can be legalized later.
SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, SDLoc(N));
// need to be moved to the Lo vector. Passthru values make this even harder.
// We try to use VECTOR_COMPRESS if the target has custom lowering with
// smaller types and passthru is undef, as it is most likely faster than the
// fully expanded path. Otherwise, just do the full expansion as one "big"
// operation and then extract the Lo and Hi vectors from that. This gets
// rid of VECTOR_COMPRESS and all other operands can be legalized later.
SDLoc DL(N);
EVT VecVT = N->getValueType(0);

auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VecVT);
bool HasLegalOrCustom = false;
EVT CheckVT = LoVT;
while (CheckVT.getVectorMinNumElements() > 1) {
if (TLI.isOperationLegalOrCustom(ISD::VECTOR_COMPRESS, CheckVT)) {
HasLegalOrCustom = true;
break;
}
CheckVT = CheckVT.getHalfNumVectorElementsVT(*DAG.getContext());
}

SDValue Passthru = N->getOperand(2);
if (!HasLegalOrCustom || !Passthru.isUndef()) {
SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL, LoVT, HiVT);
return;
}

// Try to VECTOR_COMPRESS smaller vectors and combine via a stack store+load.
SDValue LoMask, HiMask;
std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
std::tie(LoMask, HiMask) = SplitMask(N->getOperand(1));

SDValue UndefPassthru = DAG.getUNDEF(LoVT);
Lo = DAG.getNode(ISD::VECTOR_COMPRESS, DL, LoVT, Lo, LoMask, UndefPassthru);
Hi = DAG.getNode(ISD::VECTOR_COMPRESS, DL, HiVT, Hi, HiMask, UndefPassthru);

SDValue StackPtr = DAG.CreateStackTemporary(
VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
MachineFunction &MF = DAG.getMachineFunction();
MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());

// We store LoVec and then insert HiVec starting at offset=|1s| in LoMask.
SDValue WideMask =
DAG.getNode(ISD::ZERO_EXTEND, DL, LoMask.getValueType(), LoMask);
SDValue Offset = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, WideMask);
Offset = TLI.getVectorElementPointer(DAG, StackPtr, VecVT, Offset);

SDValue Chain = DAG.getEntryNode();
Chain = DAG.getStore(Chain, DL, Lo, StackPtr, PtrInfo);
Chain = DAG.getStore(Chain, DL, Hi, Offset,
MachinePointerInfo::getUnknownStack(MF));

SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL);
}
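
Editorial aside, not part of the diff: the comment at the top of SplitVecRes_VECTOR_COMPRESS is easier to follow with a scalar model. The sketch below (plain C++; the helper name and the fixed lane count are illustrative assumptions) compresses each half independently, writes the Lo result to a temporary buffer at offset 0, writes the Hi result starting at offset popcount(LoMask), and then re-splits the buffer, mirroring the stack store+load path above.

```cpp
#include <cstddef>
#include <cstdio>

constexpr std::size_t HalfLen = 4; // each split half has 4 lanes in this toy model

// Compress one half: packed survivors at the front, zeros in the tail
// (the tail is "don't care" in the DAG; zero is just this model's filler).
static void compressHalf(const int Vec[HalfLen], const bool Mask[HalfLen],
                         int Out[HalfLen]) {
  std::size_t Idx = 0;
  for (std::size_t I = 0; I < HalfLen; ++I)
    Out[I] = 0;
  for (std::size_t I = 0; I < HalfLen; ++I)
    if (Mask[I])
      Out[Idx++] = Vec[I];
}

int main() {
  // An 8-lane input already split into Lo/Hi value and mask halves.
  const int Lo[HalfLen] = {1, 2, 3, 4}, Hi[HalfLen] = {5, 6, 7, 8};
  const bool LoMask[HalfLen] = {true, false, true, false};
  const bool HiMask[HalfLen] = {false, true, true, true};

  int LoC[HalfLen], HiC[HalfLen];
  compressHalf(Lo, LoMask, LoC);
  compressHalf(Hi, HiMask, HiC);

  // Stack temporary: store LoC at offset 0, then HiC starting at
  // offset = popcount(LoMask), then reload the combined vector.
  int Stack[2 * HalfLen] = {};
  std::size_t Offset = 0;
  for (bool B : LoMask)
    Offset += B ? 1 : 0;
  for (std::size_t I = 0; I < HalfLen; ++I)
    Stack[I] = LoC[I];
  for (std::size_t I = 0; I < HalfLen; ++I)
    Stack[Offset + I] = HiC[I];

  // Re-splitting Stack into its two halves yields the final Lo/Hi results.
  for (std::size_t I = 0; I < 2 * HalfLen; ++I)
    std::printf("%d ", Stack[I]); // prints: 1 3 6 7 8 0 0 0
  std::printf("\n");
}
```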

void DAGTypeLegalizer::SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
@@ -5784,7 +5834,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_VECTOR_COMPRESS(SDNode *N) {
TLI.getTypeToTransformTo(*DAG.getContext(), Vec.getValueType());
EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
Mask.getValueType().getVectorElementType(),
WideVecVT.getVectorNumElements());
WideVecVT.getVectorElementCount());

SDValue WideVec = ModifyToType(Vec, WideVecVT);
SDValue WideMask = ModifyToType(Mask, WideMaskVT, /*FillWithZeroes=*/true);
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (215 changes: 215 additions & 0 deletions)
@@ -1535,6 +1535,24 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
}
}

// We can lower types that have <vscale x {2|4}> elements to svcompact and
// legal i8/i16 types via a compressing store.
for (auto VT :
{MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32,
MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

// If we have SVE, we can use SVE logic for legal (or smaller than legal)
// NEON vectors in the lowest bits of the SVE register.
if (Subtarget->hasSVE())
for (auto VT :
{MVT::v1i8, MVT::v1i16, MVT::v1i32, MVT::v1i64, MVT::v1f32,
MVT::v1f64, MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64,
MVT::v2f32, MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32,
MVT::v4f32, MVT::v8i8, MVT::v8i16, MVT::v16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

// NEON doesn't support masked loads/stores, but SME and SVE do.
for (auto VT :
{MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, MVT::v1f64,
@@ -6615,6 +6633,132 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
return DAG.getMergeValues({Ext, Chain}, DL);
}

SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
SDValue Vec = Op.getOperand(0);
SDValue Mask = Op.getOperand(1);
SDValue Passthru = Op.getOperand(2);
EVT VecVT = Vec.getValueType();
EVT MaskVT = Mask.getValueType();
EVT ElmtVT = VecVT.getVectorElementType();
const bool IsFixedLength = VecVT.isFixedLengthVector();
const bool HasPassthru = !Passthru.isUndef();
unsigned MinElmts = VecVT.getVectorElementCount().getKnownMinValue();
EVT FixedVecVT = MVT::getVectorVT(ElmtVT.getSimpleVT(), MinElmts);

assert(VecVT.isVector() && "Input to VECTOR_COMPRESS must be vector.");

if (!Subtarget->hasSVE())
return SDValue();

if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
return SDValue();

// We can use the SVE register containing the NEON vector in its lowest bits.
if (IsFixedLength) {
EVT ScalableVecVT =
MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), MinElmts);
EVT ScalableMaskVT = MVT::getScalableVectorVT(
MaskVT.getVectorElementType().getSimpleVT(), MinElmts);

Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
DAG.getUNDEF(ScalableVecVT), Vec,
DAG.getConstant(0, DL, MVT::i64));
Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableMaskVT,
DAG.getUNDEF(ScalableMaskVT), Mask,
DAG.getConstant(0, DL, MVT::i64));
Mask = DAG.getNode(ISD::TRUNCATE, DL,
ScalableMaskVT.changeVectorElementType(MVT::i1), Mask);
Passthru = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
DAG.getUNDEF(ScalableVecVT), Passthru,
DAG.getConstant(0, DL, MVT::i64));

VecVT = Vec.getValueType();
MaskVT = Mask.getValueType();
}

// Special case where we can't use svcompact but can do a compressing store
// and then reload the vector.
if (VecVT == MVT::nxv8i8 || VecVT == MVT::nxv16i8 || VecVT == MVT::nxv8i16) {
SDValue StackPtr = DAG.CreateStackTemporary(VecVT);
int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
MachinePointerInfo PtrInfo =
MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);

MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
PtrInfo, MachineMemOperand::Flags::MOStore,
LocationSize::precise(VecVT.getStoreSize()),
DAG.getReducedAlign(VecVT, /*UseABI=*/false));

SDValue Chain = DAG.getEntryNode();
if (HasPassthru)
Chain = DAG.getStore(Chain, DL, Passthru, StackPtr, PtrInfo);

Chain = DAG.getMaskedStore(Chain, DL, Vec, StackPtr, DAG.getUNDEF(MVT::i64),
Mask, VecVT, MMO, ISD::UNINDEXED,
/*IsTruncating=*/false, /*IsCompressing=*/true);

SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);

if (IsFixedLength)
Compressed = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FixedVecVT,
Compressed, DAG.getConstant(0, DL, MVT::i64));

return Compressed;
}

// Only <vscale x {2|4} x {i32|i64}> supported for svcompact.
if (MinElmts != 2 && MinElmts != 4)
return SDValue();

// Get legal type for svcompact instruction
EVT ContainerVT = getSVEContainerType(VecVT);
EVT CastVT = VecVT.changeVectorElementTypeToInteger();

// Convert to i32 or i64 for smaller types, as these are the only supported
// sizes for svcompact.
if (ContainerVT != VecVT) {
Vec = DAG.getBitcast(CastVT, Vec);
Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
}

SDValue Compressed = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);

// svcompact fills with 0s, so if our passthru is all 0s, do nothing here.
if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
SDValue Offset = DAG.getNode(
ISD::ZERO_EXTEND, DL, MaskVT.changeVectorElementType(MVT::i32), Mask);
Offset = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, Offset);
Compressed =
DAG.getNode(ISD::VP_MERGE, DL, VecVT,
DAG.getSplatVector(MaskVT, DL,
DAG.getAllOnesConstant(
DL, MaskVT.getVectorElementType())),
Compressed, Passthru, Offset);
}

Review thread on the VP_MERGE above:

davemgreen (Collaborator), Jul 29, 2024:
VP_MERGE is not really supported or encouraged by the AArch64 backend. Is there an alternative we can emit?

lawben (PR author):
What would be the AArch64 way to express this logic? I copied it from a RISC-V example that came up in the original discussion.

Collaborator:
Regarding "If a passthru vector is given, all remaining lanes are filled with the corresponding lane's value from passthru": what does "corresponding lane" mean in this case? If the mask is <1, 0, 0, 1>, would the passthru for the zeroed lanes be expected to be <_, _, p, p> or <_, p, p, _> (where 'p' means the passthru value and '_' means don't care)? If the former, then I guess you could do a popcount of the predicate, create a mask from that, and then do a vector select. This also makes me wonder whether it would be better to define the intrinsic so that the other lanes are undefined, rather than adding a passthru parameter to the intrinsic. That would make the operation easier to codegen, and we could use existing intrinsics to implement the passthru behaviour.

Collaborator:
The passthru exists because it's useful for some combinations of target and passthru value. For SVE in particular, a non-zero passthru means we need to explicitly construct a mask, but other targets support it directly. This was discussed in #92289.

lawben (PR author):
@sdesmalen-arm The "corresponding lanes" are the remainder, used to fill up empty slots in the output. So vec=<a, b, c, d>, mask=<1, 0, 0, 1>, passthru=<w, x, y, z> would result in <a, d, y, z>. In your example, it would be <_, _, p, p>.

lawben (PR author):
@davemgreen Using VP_MERGE results in the following assembly (in one of the tests):

cntp	x8, p0, p0.s
index	z2.s, #0, #1
compact	z0.s, p0, z0.s
ptrue	p1.s
mov	z3.s, w8
cmphi	p1.s, p1/z, z3.s, z2.s
sel	z0.s, p1, z0.s, z1.s

So it is doing the right thing. I could manually add the instructions instead of using VP_MERGE, but I'm not sure that makes sense. What would you suggest here?
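
Editorial aside, not part of the PR: a minimal scalar C++ model of the passthru semantics and of the popcount-based select discussed in this thread. Lanes below popcount(mask) receive the packed selected elements; every other lane keeps the passthru value at the same index. The function name is invented for illustration.

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// result[i] = i-th selected element for i < popcount(mask);
// result[i] = passthru[i] otherwise, i.e. passthru lanes keep their own index.
static std::vector<int> compressWithPassthru(const std::vector<int> &Vec,
                                             const std::vector<bool> &Mask,
                                             const std::vector<int> &Passthru) {
  std::vector<int> Result(Passthru); // start from the passthru lanes
  std::size_t Count = 0;
  for (std::size_t I = 0; I < Vec.size(); ++I)
    if (Mask[I])
      Result[Count++] = Vec[I];
  // Lanes with index >= Count were never overwritten, which is what the
  // popcount + lane-index compare + select (cntp / index / cmphi / sel in the
  // assembly quoted above) achieves on SVE.
  return Result;
}

int main() {
  // The example from the thread: vec=<a,b,c,d>, mask=<1,0,0,1>,
  // passthru=<w,x,y,z> gives <a,d,y,z> (here a..d = 1..4, w..z = 10..13).
  std::vector<int> Vec = {1, 2, 3, 4}, Passthru = {10, 11, 12, 13};
  std::vector<bool> Mask = {true, false, false, true};
  for (int V : compressWithPassthru(Vec, Mask, Passthru))
    std::printf("%d ", V); // prints: 1 4 12 13
  std::printf("\n");
}
```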

// Extracting from a legal SVE type before truncating produces better code.
if (IsFixedLength) {
Compressed = DAG.getNode(
ISD::EXTRACT_SUBVECTOR, DL,
FixedVecVT.changeVectorElementType(ContainerVT.getVectorElementType()),
Compressed, DAG.getConstant(0, DL, MVT::i64));
CastVT = FixedVecVT.changeVectorElementTypeToInteger();
VecVT = FixedVecVT;
}

// If we changed the element type before, we need to convert it back.
if (ContainerVT != VecVT) {
Compressed = DAG.getNode(ISD::TRUNCATE, DL, CastVT, Compressed);
Compressed = DAG.getBitcast(VecVT, Compressed);
}

return Compressed;
}

// Generate SUBS and CSEL for integer abs.
SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
MVT VT = Op.getSimpleValueType();
@@ -6995,6 +7139,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerDYNAMIC_STACKALLOC(Op, DAG);
case ISD::VSCALE:
return LowerVSCALE(Op, DAG);
case ISD::VECTOR_COMPRESS:
return LowerVECTOR_COMPRESS(Op, DAG);
case ISD::ANY_EXTEND:
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
@@ -22928,6 +23074,68 @@ static SDValue combineI8TruncStore(StoreSDNode *ST, SelectionDAG &DAG,
return Chain;
}

static SDValue combineVECTOR_COMPRESSStore(SelectionDAG &DAG,
StoreSDNode *Store,
const AArch64Subtarget *Subtarget) {
// If the regular store is preceded by a VECTOR_COMPRESS, we can combine them
// into a compressing store for scalable vectors in SVE.
SDValue VecOp = Store->getValue();
EVT VecVT = VecOp.getValueType();
if (VecOp.getOpcode() != ISD::VECTOR_COMPRESS || !Subtarget->hasSVE())
return SDValue();

bool IsFixedLength = VecVT.isFixedLengthVector();
if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
return SDValue();

SDLoc DL(Store);
SDValue Vec = VecOp.getOperand(0);
SDValue Mask = VecOp.getOperand(1);
SDValue Passthru = VecOp.getOperand(2);
EVT MemVT = Store->getMemoryVT();
MachineMemOperand *MMO = Store->getMemOperand();
SDValue Chain = Store->getChain();

// We can use the SVE register containing the NEON vector in its lowest bits.
if (IsFixedLength) {
EVT ElmtVT = VecVT.getVectorElementType();
unsigned NumElmts = VecVT.getVectorNumElements();
EVT ScalableVecVT =
MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), NumElmts);
EVT ScalableMaskVT = MVT::getScalableVectorVT(
Mask.getValueType().getVectorElementType().getSimpleVT(), NumElmts);

Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
DAG.getUNDEF(ScalableVecVT), Vec,
DAG.getConstant(0, DL, MVT::i64));
Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableMaskVT,
DAG.getUNDEF(ScalableMaskVT), Mask,
DAG.getConstant(0, DL, MVT::i64));
Mask = DAG.getNode(ISD::TRUNCATE, DL,
ScalableMaskVT.changeVectorElementType(MVT::i1), Mask);
Passthru = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
DAG.getUNDEF(ScalableVecVT), Passthru,
DAG.getConstant(0, DL, MVT::i64));

MemVT = ScalableVecVT;
MMO->setType(LLT::scalable_vector(NumElmts, ElmtVT.getSizeInBits()));
}

// If the passthru is all 0s, we don't need an explicit passthru store.
unsigned MinElmts = VecVT.getVectorMinNumElements();
if (ISD::isConstantSplatVectorAllZeros(Passthru.getNode()) &&
(MinElmts == 2 || MinElmts == 4))
return SDValue();

if (!Passthru.isUndef())
Chain = DAG.getStore(Chain, DL, Passthru, Store->getBasePtr(), MMO);

return DAG.getMaskedStore(Chain, DL, Vec, Store->getBasePtr(),
DAG.getUNDEF(MVT::i64), Mask, MemVT, MMO,
ISD::UNINDEXED, Store->isTruncatingStore(),
/*IsCompressing=*/true);
}
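
Editorial aside, not part of the diff: a scalar C++ sketch of the equivalence combineVECTOR_COMPRESSStore relies on. Storing vector.compress(vec, mask, passthru) produces the same memory contents as first storing the passthru and then performing a compressing store of the selected elements over it; when the passthru is undef (or all zeros on the zero-filling compact path) the initial store can be dropped. Names and the lane count are assumptions.

```cpp
#include <cstdio>

constexpr int Lanes = 4;

// Reference behaviour: materialize vector.compress(vec, mask, passthru),
// then store it with a regular store.
static void storeCompressed(const int Vec[Lanes], const bool Mask[Lanes],
                            const int Passthru[Lanes], int Mem[Lanes]) {
  int Compressed[Lanes];
  int Count = 0;
  for (int I = 0; I < Lanes; ++I)
    Compressed[I] = Passthru[I]; // remaining lanes come from the passthru
  for (int I = 0; I < Lanes; ++I)
    if (Mask[I])
      Compressed[Count++] = Vec[I];
  for (int I = 0; I < Lanes; ++I) // the regular store being combined away
    Mem[I] = Compressed[I];
}

// Combined form: plain store of the passthru (skippable when it is undef or
// handled by the zero-filling compact path), followed by a compressing store
// that only rewrites the first popcount(mask) lanes.
static void combinedStore(const int Vec[Lanes], const bool Mask[Lanes],
                          const int Passthru[Lanes], int Mem[Lanes]) {
  for (int I = 0; I < Lanes; ++I)
    Mem[I] = Passthru[I];
  int Count = 0;
  for (int I = 0; I < Lanes; ++I)
    if (Mask[I])
      Mem[Count++] = Vec[I];
}

int main() {
  const int Vec[Lanes] = {1, 2, 3, 4}, Passthru[Lanes] = {10, 11, 12, 13};
  const bool Mask[Lanes] = {true, false, false, true};
  int A[Lanes], B[Lanes];
  storeCompressed(Vec, Mask, Passthru, A);
  combinedStore(Vec, Mask, Passthru, B);
  for (int I = 0; I < Lanes; ++I)
    std::printf("%d/%d ", A[I], B[I]); // both paths agree: 1/1 4/4 12/12 13/13
  std::printf("\n");
}
```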

static SDValue performSTORECombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG,
@@ -22972,6 +23180,9 @@ static SDValue performSTORECombine(SDNode *N,
if (SDValue Store = combineBoolVectorAndTruncateStore(DAG, ST))
return Store;

if (SDValue Store = combineVECTOR_COMPRESSStore(DAG, ST, Subtarget))
return Store;

if (ST->isTruncatingStore()) {
EVT StoreVT = ST->getMemoryVT();
if (!isHalvingTruncateOfLegalScalableType(ValueVT, StoreVT))
@@ -26214,6 +26425,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
case ISD::VECREDUCE_UMIN:
Results.push_back(LowerVECREDUCE(SDValue(N, 0), DAG));
return;
case ISD::VECTOR_COMPRESS:
if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
case ISD::ADD:
case ISD::FADD:
ReplaceAddWithADDP(N, Results, DAG, Subtarget);
llvm/lib/Target/AArch64/AArch64ISelLowering.h (2 changes: 2 additions & 0 deletions)
@@ -1073,6 +1073,8 @@ class AArch64TargetLowering : public TargetLowering {

SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerVECTOR_COMPRESS(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINTRINSIC_VOID(SDValue Op, SelectionDAG &DAG) const;