Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[x86] Add lowering for @llvm.experimental.vector.compress #104904

Merged
merged 9 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SplitVecOp_VP_REDUCE(SDNode *N, unsigned OpNo);
SDValue SplitVecOp_UnaryOp(SDNode *N);
SDValue SplitVecOp_TruncateHelper(SDNode *N);
SDValue SplitVecOp_VECTOR_COMPRESS(SDNode *N, unsigned OpNo);

SDValue SplitVecOp_BITCAST(SDNode *N);
SDValue SplitVecOp_INSERT_SUBVECTOR(SDNode *N, unsigned OpNo);
Expand Down
26 changes: 24 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2436,16 +2436,17 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS(SDNode *N, SDValue &Lo,
}

SDValue Passthru = N->getOperand(2);
if (!HasCustomLowering || !Passthru.isUndef()) {
if (!HasCustomLowering) {
SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL, LoVT, HiVT);
return;
}

// Try to VECTOR_COMPRESS smaller vectors and combine via a stack store+load.
SDValue Mask = N->getOperand(1);
SDValue LoMask, HiMask;
std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
std::tie(LoMask, HiMask) = SplitMask(N->getOperand(1));
std::tie(LoMask, HiMask) = SplitMask(Mask);

SDValue UndefPassthru = DAG.getUNDEF(LoVT);
Lo = DAG.getNode(ISD::VECTOR_COMPRESS, DL, LoVT, Lo, LoMask, UndefPassthru);
Expand All @@ -2469,6 +2470,10 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS(SDNode *N, SDValue &Lo,
MachinePointerInfo::getUnknownStack(MF));

SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
if (!Passthru.isUndef()) {
Compressed =
DAG.getNode(ISD::VSELECT, DL, VecVT, Mask, Compressed, Passthru);
}
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL);
}

Expand Down Expand Up @@ -3226,6 +3231,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VSELECT:
Res = SplitVecOp_VSELECT(N, OpNo);
break;
case ISD::VECTOR_COMPRESS:
Res = SplitVecOp_VECTOR_COMPRESS(N, OpNo);
break;
case ISD::STRICT_SINT_TO_FP:
case ISD::STRICT_UINT_TO_FP:
case ISD::SINT_TO_FP:
Expand Down Expand Up @@ -3372,6 +3380,20 @@ SDValue DAGTypeLegalizer::SplitVecOp_VSELECT(SDNode *N, unsigned OpNo) {
return DAG.getNode(ISD::CONCAT_VECTORS, DL, Src0VT, LoSelect, HiSelect);
}

SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_COMPRESS(SDNode *N, unsigned OpNo) {
// The only possibility for an illegal operand is the mask, since result type
// legalization would have handled this node already otherwise.
assert(OpNo == 1 && "Illegal operand must be mask");

// To split the mask, we need to split the result type too, so we can just
// reuse that logic here.
SDValue Lo, Hi;
SplitVecRes_VECTOR_COMPRESS(N, Lo, Hi);

EVT VecVT = N->getValueType(0);
return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VecVT, Lo, Hi);
}

SDValue DAGTypeLegalizer::SplitVecOp_VECREDUCE(SDNode *N, unsigned OpNo) {
EVT ResVT = N->getValueType(0);
SDValue Lo, Hi;
Expand Down
14 changes: 9 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11582,11 +11582,13 @@ SDValue TargetLowering::expandVECTOR_COMPRESS(SDNode *Node,
// ... if it is not a splat vector, we need to get the passthru value at
// position = popcount(mask) and re-load it from the stack before it is
// overwritten in the loop below.
EVT PopcountVT = ScalarVT.changeTypeToInteger();
SDValue Popcount = DAG.getNode(
ISD::TRUNCATE, DL, MaskVT.changeVectorElementType(MVT::i1), Mask);
Popcount = DAG.getNode(ISD::ZERO_EXTEND, DL,
MaskVT.changeVectorElementType(ScalarVT), Popcount);
Popcount = DAG.getNode(ISD::VECREDUCE_ADD, DL, ScalarVT, Popcount);
Popcount =
DAG.getNode(ISD::ZERO_EXTEND, DL,
MaskVT.changeVectorElementType(PopcountVT), Popcount);
Popcount = DAG.getNode(ISD::VECREDUCE_ADD, DL, PopcountVT, Popcount);
SDValue LastElmtPtr =
getVectorElementPointer(DAG, StackPtr, VecVT, Popcount);
LastWriteVal = DAG.getLoad(
Expand Down Expand Up @@ -11625,8 +11627,10 @@ SDValue TargetLowering::expandVECTOR_COMPRESS(SDNode *Node,

// Re-write the last ValI if all lanes were selected. Otherwise,
// overwrite the last write it with the passthru value.
LastWriteVal =
DAG.getSelect(DL, ScalarVT, AllLanesSelected, ValI, LastWriteVal);
SDNodeFlags Flags{};
Flags.setUnpredictable(true);
LastWriteVal = DAG.getSelect(DL, ScalarVT, AllLanesSelected, ValI,
LastWriteVal, Flags);
Chain = DAG.getStore(
Chain, DL, LastWriteVal, OutPtr,
MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
Expand Down
92 changes: 92 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2125,6 +2125,35 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64 })
setOperationAction(ISD::CTPOP, VT, Legal);
}

// We can try to convert vectors to different sizes to leverage legal
// `vpcompress` cases. So we mark these supported vector sizes as Custom and
// then specialize to Legal below.
for (MVT VT : {MVT::v8i32, MVT::v8f32, MVT::v4i32, MVT::v4f32, MVT::v4i64,
MVT::v4f64, MVT::v2i64, MVT::v2f64, MVT::v16i8, MVT::v8i16,
MVT::v16i16, MVT::v8i8})
lawben marked this conversation as resolved.
Show resolved Hide resolved
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

// Legal vpcompress depends on various AVX512 extensions.
// Legal in AVX512F
for (MVT VT : {MVT::v16i32, MVT::v16f32, MVT::v8i64, MVT::v8f64})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Legal);

// Legal in AVX512F + AVX512VL
if (Subtarget.hasVLX())
for (MVT VT : {MVT::v8i32, MVT::v8f32, MVT::v4i32, MVT::v4f32, MVT::v4i64,
MVT::v4f64, MVT::v2i64, MVT::v2f64})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Legal);

// Legal in AVX512F + AVX512VBMI2
if (Subtarget.hasVBMI2())
for (MVT VT : {MVT::v32i16, MVT::v64i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Legal);

// Legal in AVX512F + AVX512VL + AVX512VBMI2
if (Subtarget.hasVBMI2() && Subtarget.hasVLX())
for (MVT VT : {MVT::v16i8, MVT::v8i16, MVT::v32i8, MVT::v16i16})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Legal);
}

// This block control legalization of v32i1/v64i1 which are available with
Expand Down Expand Up @@ -17755,6 +17784,68 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, const X86Subtarget &Subtarget,
llvm_unreachable("Unimplemented!");
}

// As legal vpcompress instructions depend on various AVX512 extensions, try to
// convert illegal vector sizes to legal ones to avoid expansion.
static SDValue lowerVECTOR_COMPRESS(SDValue Op, const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
assert(Subtarget.hasAVX512() &&
"Need AVX512 for custom VECTOR_COMPRESS lowering.");

SDLoc DL(Op);
SDValue Vec = Op.getOperand(0);
SDValue Mask = Op.getOperand(1);
SDValue Passthru = Op.getOperand(2);

EVT VecVT = Vec.getValueType();
EVT ElementVT = VecVT.getVectorElementType();
unsigned NumElements = VecVT.getVectorNumElements();
unsigned NumVecBits = VecVT.getFixedSizeInBits();
unsigned NumElementBits = ElementVT.getFixedSizeInBits();

// 128- and 256-bit vectors with <= 16 elements can be converted to and
// compressed as 512-bit vectors in AVX512F.
if (NumVecBits != 128 && NumVecBits != 256)
return SDValue();

if (NumElementBits == 32 || NumElementBits == 64) {
unsigned NumLargeElements = 512 / NumElementBits;
MVT LargeVecVT =
MVT::getVectorVT(ElementVT.getSimpleVT(), NumLargeElements);
MVT LargeMaskVT = MVT::getVectorVT(MVT::i1, NumLargeElements);

Vec = widenSubVector(LargeVecVT, Vec, /*ZeroNewElements=*/false, Subtarget,
DAG, DL);
Mask = widenSubVector(LargeMaskVT, Mask, /*ZeroNewElements=*/true,
Subtarget, DAG, DL);
Passthru = Passthru.isUndef() ? DAG.getUNDEF(LargeVecVT)
: widenSubVector(LargeVecVT, Passthru,
/*ZeroNewElements=*/false,
lawben marked this conversation as resolved.
Show resolved Hide resolved
Subtarget, DAG, DL);

SDValue Compressed =
DAG.getNode(ISD::VECTOR_COMPRESS, DL, LargeVecVT, Vec, Mask, Passthru);
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, Compressed,
DAG.getConstant(0, DL, MVT::i64));
}

if (VecVT == MVT::v8i16 || VecVT == MVT::v8i8 || VecVT == MVT::v16i8 ||
VecVT == MVT::v16i16) {
MVT LageElementVT = MVT::getIntegerVT(512 / NumElements);
EVT LargeVecVT = MVT::getVectorVT(LageElementVT, NumElements);

Vec = DAG.getNode(ISD::ANY_EXTEND, DL, LargeVecVT, Vec);
Passthru = Passthru.isUndef()
? DAG.getUNDEF(LargeVecVT)
: DAG.getNode(ISD::ANY_EXTEND, DL, LargeVecVT, Passthru);

SDValue Compressed =
DAG.getNode(ISD::VECTOR_COMPRESS, DL, LargeVecVT, Vec, Mask, Passthru);
return DAG.getNode(ISD::TRUNCATE, DL, VecVT, Compressed);
}

return SDValue();
}

/// Try to lower a VSELECT instruction to a vector shuffle.
static SDValue lowerVSELECTtoVectorShuffle(SDValue Op,
const X86Subtarget &Subtarget,
Expand Down Expand Up @@ -32374,6 +32465,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::BUILD_VECTOR: return LowerBUILD_VECTOR(Op, DAG);
case ISD::CONCAT_VECTORS: return LowerCONCAT_VECTORS(Op, Subtarget, DAG);
case ISD::VECTOR_SHUFFLE: return lowerVECTOR_SHUFFLE(Op, Subtarget, DAG);
case ISD::VECTOR_COMPRESS: return lowerVECTOR_COMPRESS(Op, Subtarget, DAG);
case ISD::VSELECT: return LowerVSELECT(Op, DAG);
case ISD::EXTRACT_VECTOR_ELT: return LowerEXTRACT_VECTOR_ELT(Op, DAG);
case ISD::INSERT_VECTOR_ELT: return LowerINSERT_VECTOR_ELT(Op, DAG);
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/X86/X86InstrAVX512.td
Original file line number Diff line number Diff line change
Expand Up @@ -10543,6 +10543,12 @@ multiclass compress_by_vec_width_lowering<X86VectorVTInfo _, string Name> {
def : Pat<(X86compress (_.VT _.RC:$src), _.ImmAllZerosV, _.KRCWM:$mask),
(!cast<Instruction>(Name#_.ZSuffix#rrkz)
_.KRCWM:$mask, _.RC:$src)>;
def : Pat<(_.VT (vector_compress _.RC:$src, _.KRCWM:$mask, undef)),
(!cast<Instruction>(Name#_.ZSuffix#rrkz)
_.KRCWM:$mask, _.RC:$src)>;
def : Pat<(_.VT (vector_compress _.RC:$src, _.KRCWM:$mask, _.RC:$passthru)),
(!cast<Instruction>(Name#_.ZSuffix#rrk)
_.RC:$passthru, _.KRCWM:$mask, _.RC:$src)>;
}

multiclass compress_by_elt_width<bits<8> opc, string OpcodeStr,
Expand Down
Loading
Loading