Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AArch64] Add lowering for @llvm.experimental.vector.compress #101015

Merged
merged 10 commits into from
Aug 13, 2024

Conversation

lawben
Copy link
Contributor

@lawben lawben commented Jul 29, 2024

This is a follow-up to #92289 that adds custom lowering of the new @llvm.experimental.vector.compress intrinsic on AArch64 with SVE instructions.

Some vectors have a compact instruction that they can be lowered to.

@lawben lawben requested review from RKSimon and efriedma-quic July 29, 2024 14:10
@llvmbot llvmbot added backend:AArch64 llvm:SelectionDAG SelectionDAGISel as well labels Jul 29, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 29, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Lawrence Benson (lawben)

Changes

This is a follow-up to #92289 that adds custom lowering of the new @<!-- -->llvm.experimental.vector.compress intrinsic on AArch64 with SVE instructions.

Some vectors have a compact instruction that they can be lowered to. For other cases, we use SVE's st1b.

TODO: I still need to run this on an SVE machine.


Patch is 38.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101015.diff

4 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+56-6)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+214)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2)
  • (added) llvm/test/CodeGen/AArch64/sve-vector-compress.ll (+496)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 5672b611234b8..b42a54a56cfed 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -2408,11 +2408,61 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS(SDNode *N, SDValue &Lo,
                                                    SDValue &Hi) {
   // This is not "trivial", as there is a dependency between the two subvectors.
   // Depending on the number of 1s in the mask, the elements from the Hi vector
-  // need to be moved to the Lo vector. So we just perform this as one "big"
-  // operation and then extract the Lo and Hi vectors from that. This gets rid
-  // of VECTOR_COMPRESS and all other operands can be legalized later.
-  SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
-  std::tie(Lo, Hi) = DAG.SplitVector(Compressed, SDLoc(N));
+  // need to be moved to the Lo vector. Passthru values make this even harder.
+  // We try to use VECTOR_COMPRESS if the target has custom lowering with
+  // smaller types and passthru is undef, as it is most likely faster than the
+  // fully expand path. Otherwise, just do the full expansion as one "big"
+  // operation and then extract the Lo and Hi vectors from that. This gets
+  // rid of VECTOR_COMPRESS and all other operands can be legalized later.
+  SDLoc DL(N);
+  EVT VecVT = N->getValueType(0);
+
+  auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VecVT);
+  bool HasLegalOrCustom = false;
+  EVT CheckVT = LoVT;
+  while (CheckVT.getVectorMinNumElements() > 1) {
+    if (TLI.isOperationLegalOrCustom(ISD::VECTOR_COMPRESS, CheckVT)) {
+      HasLegalOrCustom = true;
+      break;
+    }
+    CheckVT = CheckVT.getHalfNumVectorElementsVT(*DAG.getContext());
+  }
+
+  SDValue Passthru = N->getOperand(2);
+  if (!HasLegalOrCustom || !Passthru.isUndef()) {
+    SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
+    std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL, LoVT, HiVT);
+    return;
+  }
+
+  // Try to VECTOR_COMPRESS smaller vectors and combine via a stack store+load.
+  SDValue LoMask, HiMask;
+  std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
+  std::tie(LoMask, HiMask) = SplitMask(N->getOperand(1));
+
+  SDValue UndefPassthru = DAG.getUNDEF(LoVT);
+  Lo = DAG.getNode(ISD::VECTOR_COMPRESS, DL, LoVT, Lo, LoMask, UndefPassthru);
+  Hi = DAG.getNode(ISD::VECTOR_COMPRESS, DL, HiVT, Hi, HiMask, UndefPassthru);
+
+  SDValue StackPtr = DAG.CreateStackTemporary(
+      VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
+      MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());
+
+  // We store LoVec and then insert HiVec starting at offset=|1s| in LoMask.
+  SDValue WideMask =
+      DAG.getNode(ISD::ZERO_EXTEND, DL, LoMask.getValueType(), LoMask);
+  SDValue Offset = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, WideMask);
+  Offset = TLI.getVectorElementPointer(DAG, StackPtr, VecVT, Offset);
+
+  SDValue Chain = DAG.getEntryNode();
+  Chain = DAG.getStore(Chain, DL, Lo, StackPtr, PtrInfo);
+  Chain = DAG.getStore(Chain, DL, Hi, Offset,
+                       MachinePointerInfo::getUnknownStack(MF));
+
+  SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+  std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL);
 }
 
 void DAGTypeLegalizer::SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
@@ -5784,7 +5834,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_VECTOR_COMPRESS(SDNode *N) {
       TLI.getTypeToTransformTo(*DAG.getContext(), Vec.getValueType());
   EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
                                     Mask.getValueType().getVectorElementType(),
-                                    WideVecVT.getVectorNumElements());
+                                    WideVecVT.getVectorElementCount());
 
   SDValue WideVec = ModifyToType(Vec, WideVecVT);
   SDValue WideMask = ModifyToType(Mask, WideMaskVT, /*FillWithZeroes=*/true);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1e9da9b819bdd..6bfeb4d11ec42 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1535,6 +1535,24 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       }
     }
 
+    // We can lower types that have <vscale x {2|4}> elements to svcompact and
+    // legal i8/i16 types via a compressing store.
+    for (auto VT :
+         {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
+          MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32,
+          MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8})
+      setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
+
+    // If we have SVE, we can use SVE logic for legal (or smaller than legal)
+    // NEON vectors in the lowest bits of the SVE register.
+    if (Subtarget->hasSVE())
+      for (auto VT :
+           {MVT::v1i8,  MVT::v1i16, MVT::v1i32, MVT::v1i64, MVT::v1f32,
+            MVT::v1f64, MVT::v2i8,  MVT::v2i16, MVT::v2i32, MVT::v2i64,
+            MVT::v2f32, MVT::v2f64, MVT::v4i8,  MVT::v4i16, MVT::v4i32,
+            MVT::v4f32, MVT::v8i8,  MVT::v8i16, MVT::v8i16, MVT::v16i8})
+        setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
+
     // NEON doesn't support masked loads/stores, but SME and SVE do.
     for (auto VT :
          {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, MVT::v1f64,
@@ -6615,6 +6633,132 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
   return DAG.getMergeValues({Ext, Chain}, DL);
 }
 
+SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Vec = Op.getOperand(0);
+  SDValue Mask = Op.getOperand(1);
+  SDValue Passthru = Op.getOperand(2);
+  EVT VecVT = Vec.getValueType();
+  EVT MaskVT = Mask.getValueType();
+  EVT ElmtVT = VecVT.getVectorElementType();
+  const bool IsFixedLength = VecVT.isFixedLengthVector();
+  const bool HasPassthru = !Passthru.isUndef();
+  unsigned MinElmts = VecVT.getVectorElementCount().getKnownMinValue();
+  EVT FixedVecVT = MVT::getVectorVT(ElmtVT.getSimpleVT(), MinElmts);
+
+  assert(VecVT.isVector() && "Input to VECTOR_COMPRESS must be vector.");
+
+  if (!Subtarget->hasSVE())
+    return SDValue();
+
+  if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
+    return SDValue();
+
+  // We can use the SVE register containing the NEON vector in its lowest bits.
+  if (IsFixedLength) {
+    EVT ScalableVecVT =
+        MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), MinElmts);
+    EVT ScalableMaskVT = MVT::getScalableVectorVT(
+        MaskVT.getVectorElementType().getSimpleVT(), MinElmts);
+
+    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+                      DAG.getUNDEF(ScalableVecVT), Vec,
+                      DAG.getConstant(0, DL, MVT::i64));
+    Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableMaskVT,
+                       DAG.getUNDEF(ScalableMaskVT), Mask,
+                       DAG.getConstant(0, DL, MVT::i64));
+    Mask = DAG.getNode(ISD::TRUNCATE, DL,
+                       ScalableMaskVT.changeVectorElementType(MVT::i1), Mask);
+    Passthru = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+                           DAG.getUNDEF(ScalableVecVT), Passthru,
+                           DAG.getConstant(0, DL, MVT::i64));
+
+    VecVT = Vec.getValueType();
+    MaskVT = Mask.getValueType();
+  }
+
+  // Special case where we can't use svcompact but can do a compressing store
+  // and then reload the vector.
+  if (VecVT == MVT::nxv8i8 || VecVT == MVT::nxv16i8 || VecVT == MVT::nxv8i16) {
+    SDValue StackPtr = DAG.CreateStackTemporary(VecVT);
+    int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+    MachinePointerInfo PtrInfo =
+        MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
+
+    MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
+        PtrInfo, MachineMemOperand::Flags::MOStore,
+        LocationSize::precise(VecVT.getStoreSize()),
+        DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+
+    SDValue Chain = DAG.getEntryNode();
+    if (HasPassthru)
+      Chain = DAG.getStore(Chain, DL, Passthru, StackPtr, PtrInfo);
+
+    Chain = DAG.getMaskedStore(Chain, DL, Vec, StackPtr, DAG.getUNDEF(MVT::i64),
+                               Mask, VecVT, MMO, ISD::UNINDEXED,
+                               /*IsTruncating=*/false, /*IsCompressing=*/true);
+
+    SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+
+    if (IsFixedLength)
+      Compressed = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FixedVecVT,
+                               Compressed, DAG.getConstant(0, DL, MVT::i64));
+
+    return Compressed;
+  }
+
+  // Only <vscale x {2|4} x {i32|i64}> supported for svcompact.
+  if (MinElmts != 2 && MinElmts != 4)
+    return SDValue();
+
+  // Get legal type for svcompact instruction
+  EVT ContainerVT = getSVEContainerType(VecVT);
+  EVT CastVT = VecVT.changeVectorElementTypeToInteger();
+
+  // Convert to i32 or i64 for smaller types, as these are the only supported
+  // sizes for svcompact.
+  if (ContainerVT != VecVT) {
+    Vec = DAG.getBitcast(CastVT, Vec);
+    Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
+  }
+
+  SDValue Compressed = DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
+      DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
+
+  // svcompact fills with 0s, so if our passthru is all 0s, do nothing here.
+  if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
+    SDValue Offset = DAG.getNode(
+        ISD::ZERO_EXTEND, DL, MaskVT.changeVectorElementType(MVT::i32), Mask);
+    Offset = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, Offset);
+    Compressed =
+        DAG.getNode(ISD::VP_MERGE, DL, VecVT,
+                    DAG.getSplatVector(MaskVT, DL,
+                                       DAG.getAllOnesConstant(
+                                           DL, MaskVT.getVectorElementType())),
+                    Compressed, Passthru, Offset);
+  }
+
+  // Extracting from a legal SVE type before truncating produces better code.
+  if (IsFixedLength) {
+    Compressed = DAG.getNode(
+        ISD::EXTRACT_SUBVECTOR, DL,
+        FixedVecVT.changeVectorElementType(ContainerVT.getVectorElementType()),
+        Compressed, DAG.getConstant(0, DL, MVT::i64));
+    CastVT = FixedVecVT.changeVectorElementTypeToInteger();
+    VecVT = FixedVecVT;
+  }
+
+  // If we changed the element type before, we need to convert it back.
+  if (ContainerVT != VecVT) {
+    Compressed = DAG.getNode(ISD::TRUNCATE, DL, CastVT, Compressed);
+    Compressed = DAG.getBitcast(VecVT, Compressed);
+  }
+
+  return Compressed;
+}
+
 // Generate SUBS and CSEL for integer abs.
 SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
   MVT VT = Op.getSimpleValueType();
@@ -6995,6 +7139,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerDYNAMIC_STACKALLOC(Op, DAG);
   case ISD::VSCALE:
     return LowerVSCALE(Op, DAG);
+  case ISD::VECTOR_COMPRESS:
+    return LowerVECTOR_COMPRESS(Op, DAG);
   case ISD::ANY_EXTEND:
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
@@ -22928,6 +23074,67 @@ static SDValue combineI8TruncStore(StoreSDNode *ST, SelectionDAG &DAG,
   return Chain;
 }
 
+static SDValue combineVECTOR_COMPRESSStore(SelectionDAG &DAG,
+                                           StoreSDNode *Store,
+                                           const AArch64Subtarget *Subtarget) {
+  // If the regular store is preceded by an VECTOR_COMPRESS, we can combine them
+  // into a compressing store for scalable vectors in SVE.
+  SDValue VecOp = Store->getValue();
+  EVT VecVT = VecOp.getValueType();
+  if (VecOp.getOpcode() != ISD::VECTOR_COMPRESS || !Subtarget->hasSVE())
+    return SDValue();
+
+  bool IsFixedLength = VecVT.isFixedLengthVector();
+  if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
+    return SDValue();
+
+  SDLoc DL(Store);
+  SDValue Vec = VecOp.getOperand(0);
+  SDValue Mask = VecOp.getOperand(1);
+  SDValue Passthru = VecOp.getOperand(2);
+  EVT MemVT = Store->getMemoryVT();
+  MachineMemOperand *MMO = Store->getMemOperand();
+  SDValue Chain = Store->getChain();
+
+  // We can use the SVE register containing the NEON vector in its lowest bits.
+  if (IsFixedLength) {
+    EVT ElmtVT = VecVT.getVectorElementType();
+    unsigned NumElmts = VecVT.getVectorNumElements();
+    EVT ScalableVecVT =
+        MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), NumElmts);
+    EVT ScalableMaskVT = MVT::getScalableVectorVT(
+        Mask.getValueType().getVectorElementType().getSimpleVT(), NumElmts);
+
+    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+                      DAG.getUNDEF(ScalableVecVT), Vec,
+                      DAG.getConstant(0, DL, MVT::i64));
+    Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableMaskVT,
+                       DAG.getUNDEF(ScalableMaskVT), Mask,
+                       DAG.getConstant(0, DL, MVT::i64));
+    Mask = DAG.getNode(ISD::TRUNCATE, DL,
+                       ScalableMaskVT.changeVectorElementType(MVT::i1), Mask);
+    Passthru = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+                           DAG.getUNDEF(ScalableVecVT), Passthru,
+                           DAG.getConstant(0, DL, MVT::i64));
+
+    MemVT = ScalableVecVT;
+    MMO->setType(LLT::scalable_vector(NumElmts, ElmtVT.getSizeInBits()));
+  }
+
+  // If the passthru is all 0s, we don't need an explicit passthru store.
+  unsigned MinElmts = VecVT.getVectorMinNumElements();
+  if (ISD::isConstantSplatVectorAllZeros(Passthru.getNode()) && (MinElmts == 2 || MinElmts == 4))
+    return SDValue();
+
+  if (!Passthru.isUndef())
+    Chain = DAG.getStore(Chain, DL, Passthru, Store->getBasePtr(), MMO);
+
+  return DAG.getMaskedStore(Chain, DL, Vec, Store->getBasePtr(),
+                            DAG.getUNDEF(MVT::i64), Mask, MemVT, MMO,
+                            ISD::UNINDEXED, Store->isTruncatingStore(),
+                            /*IsCompressing=*/true);
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -22972,6 +23179,9 @@ static SDValue performSTORECombine(SDNode *N,
   if (SDValue Store = combineBoolVectorAndTruncateStore(DAG, ST))
     return Store;
 
+  if (SDValue Store = combineVECTOR_COMPRESSStore(DAG, ST, Subtarget))
+    return Store;
+
   if (ST->isTruncatingStore()) {
     EVT StoreVT = ST->getMemoryVT();
     if (!isHalvingTruncateOfLegalScalableType(ValueVT, StoreVT))
@@ -26214,6 +26424,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
   case ISD::VECREDUCE_UMIN:
     Results.push_back(LowerVECREDUCE(SDValue(N, 0), DAG));
     return;
+  case ISD::VECTOR_COMPRESS:
+    if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
+      Results.push_back(Res);
+    return;
   case ISD::ADD:
   case ISD::FADD:
     ReplaceAddWithADDP(N, Results, DAG, Subtarget);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 81e15185f985d..517b1ba1fd400 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1073,6 +1073,8 @@ class AArch64TargetLowering : public TargetLowering {
 
   SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerVECTOR_COMPRESS(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINTRINSIC_VOID(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
new file mode 100644
index 0000000000000..cdebb0db47ceb
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
@@ -0,0 +1,496 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64 -mattr=+sve < %s | FileCheck %s
+
+define <vscale x 2 x i8> @test_compress_nxv2i8(<vscale x 2 x i8> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv2i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    ret
+    %out = call <vscale x 2 x i8> @llvm.experimental.vector.compress(<vscale x 2 x i8> %vec, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
+    ret <vscale x 2 x i8> %out
+}
+
+define <vscale x 2 x i16> @test_compress_nxv2i16(<vscale x 2 x i16> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv2i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    ret
+    %out = call <vscale x 2 x i16> @llvm.experimental.vector.compress(<vscale x 2 x i16> %vec, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
+    ret <vscale x 2 x i16> %out
+}
+
+define <vscale x 2 x i32> @test_compress_nxv2i32(<vscale x 2 x i32> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv2i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    ret
+    %out = call <vscale x 2 x i32> @llvm.experimental.vector.compress(<vscale x 2 x i32> %vec, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
+    ret <vscale x 2 x i32> %out
+}
+
+define <vscale x 2 x i64> @test_compress_nxv2i64(<vscale x 2 x i64> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    ret
+    %out = call <vscale x 2 x i64> @llvm.experimental.vector.compress(<vscale x 2 x i64> %vec, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef)
+    ret <vscale x 2 x i64> %out
+}
+
+define <vscale x 2 x float> @test_compress_nxv2f32(<vscale x 2 x float> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv2f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    ret
+    %out = call <vscale x 2 x float> @llvm.experimental.vector.compress(<vscale x 2 x float> %vec, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef)
+    ret <vscale x 2 x float> %out
+}
+
+define <vscale x 2 x double> @test_compress_nxv2f64(<vscale x 2 x double> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    ret
+    %out = call <vscale x 2 x double> @llvm.experimental.vector.compress(<vscale x 2 x double> %vec, <vscale x 2 x i1> %mask, <vscale x 2 x double> undef)
+    ret <vscale x 2 x double> %out
+}
+
+define <vscale x 4 x i8> @test_compress_nxv4i8(<vscale x 4 x i8> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv4i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    ret
+    %out = call <vscale x 4 x i8> @llvm.experimental.vector.compress(<vscale x 4 x i8> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef)
+    ret <vscale x 4 x i8> %out
+}
+
+define <vscale x 4 x i16> @test_compress_nxv4i16(<vscale x 4 x i16> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv4i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    ret
+    %out = call <vscale x 4 x i16> @llvm.experimental.vector.compress(<vscale x 4 x i16> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef)
+    ret <vscale x 4 x i16> %out
+}
+
+define <vscale x 4 x i32> @test_compress_nxv4i32(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    ret
+    %out = call <vscale x 4 x i32> @llvm.experimental.vector.compress(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+    ret <vscale x 4 x i32> %out
+}
+
+define <vscale x 4 x float> @test_compress_nxv4f32(<vscale x 4 x float> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv4f32:
+; CHE...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 29, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Lawrence Benson (lawben)

Changes

This is a follow-up to #92289 that adds custom lowering of the new @<!-- -->llvm.experimental.vector.compress intrinsic on AArch64 with SVE instructions.

Some vectors have a compact instruction that they can be lowered to. For other cases, we use SVE's st1b.

TODO: I still need to run this on an SVE machine.


Patch is 38.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101015.diff

4 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+56-6)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+214)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2)
  • (added) llvm/test/CodeGen/AArch64/sve-vector-compress.ll (+496)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 5672b611234b8..b42a54a56cfed 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -2408,11 +2408,61 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS(SDNode *N, SDValue &Lo,
                                                    SDValue &Hi) {
   // This is not "trivial", as there is a dependency between the two subvectors.
   // Depending on the number of 1s in the mask, the elements from the Hi vector
-  // need to be moved to the Lo vector. So we just perform this as one "big"
-  // operation and then extract the Lo and Hi vectors from that. This gets rid
-  // of VECTOR_COMPRESS and all other operands can be legalized later.
-  SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
-  std::tie(Lo, Hi) = DAG.SplitVector(Compressed, SDLoc(N));
+  // need to be moved to the Lo vector. Passthru values make this even harder.
+  // We try to use VECTOR_COMPRESS if the target has custom lowering with
+  // smaller types and passthru is undef, as it is most likely faster than the
+  // fully expand path. Otherwise, just do the full expansion as one "big"
+  // operation and then extract the Lo and Hi vectors from that. This gets
+  // rid of VECTOR_COMPRESS and all other operands can be legalized later.
+  SDLoc DL(N);
+  EVT VecVT = N->getValueType(0);
+
+  auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VecVT);
+  bool HasLegalOrCustom = false;
+  EVT CheckVT = LoVT;
+  while (CheckVT.getVectorMinNumElements() > 1) {
+    if (TLI.isOperationLegalOrCustom(ISD::VECTOR_COMPRESS, CheckVT)) {
+      HasLegalOrCustom = true;
+      break;
+    }
+    CheckVT = CheckVT.getHalfNumVectorElementsVT(*DAG.getContext());
+  }
+
+  SDValue Passthru = N->getOperand(2);
+  if (!HasLegalOrCustom || !Passthru.isUndef()) {
+    SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
+    std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL, LoVT, HiVT);
+    return;
+  }
+
+  // Try to VECTOR_COMPRESS smaller vectors and combine via a stack store+load.
+  SDValue LoMask, HiMask;
+  std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
+  std::tie(LoMask, HiMask) = SplitMask(N->getOperand(1));
+
+  SDValue UndefPassthru = DAG.getUNDEF(LoVT);
+  Lo = DAG.getNode(ISD::VECTOR_COMPRESS, DL, LoVT, Lo, LoMask, UndefPassthru);
+  Hi = DAG.getNode(ISD::VECTOR_COMPRESS, DL, HiVT, Hi, HiMask, UndefPassthru);
+
+  SDValue StackPtr = DAG.CreateStackTemporary(
+      VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
+      MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());
+
+  // We store LoVec and then insert HiVec starting at offset=|1s| in LoMask.
+  SDValue WideMask =
+      DAG.getNode(ISD::ZERO_EXTEND, DL, LoMask.getValueType(), LoMask);
+  SDValue Offset = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, WideMask);
+  Offset = TLI.getVectorElementPointer(DAG, StackPtr, VecVT, Offset);
+
+  SDValue Chain = DAG.getEntryNode();
+  Chain = DAG.getStore(Chain, DL, Lo, StackPtr, PtrInfo);
+  Chain = DAG.getStore(Chain, DL, Hi, Offset,
+                       MachinePointerInfo::getUnknownStack(MF));
+
+  SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+  std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL);
 }
 
 void DAGTypeLegalizer::SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
@@ -5784,7 +5834,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_VECTOR_COMPRESS(SDNode *N) {
       TLI.getTypeToTransformTo(*DAG.getContext(), Vec.getValueType());
   EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
                                     Mask.getValueType().getVectorElementType(),
-                                    WideVecVT.getVectorNumElements());
+                                    WideVecVT.getVectorElementCount());
 
   SDValue WideVec = ModifyToType(Vec, WideVecVT);
   SDValue WideMask = ModifyToType(Mask, WideMaskVT, /*FillWithZeroes=*/true);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1e9da9b819bdd..6bfeb4d11ec42 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1535,6 +1535,24 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       }
     }
 
+    // We can lower types that have <vscale x {2|4}> elements to svcompact and
+    // legal i8/i16 types via a compressing store.
+    for (auto VT :
+         {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
+          MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32,
+          MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8})
+      setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
+
+    // If we have SVE, we can use SVE logic for legal (or smaller than legal)
+    // NEON vectors in the lowest bits of the SVE register.
+    if (Subtarget->hasSVE())
+      for (auto VT :
+           {MVT::v1i8,  MVT::v1i16, MVT::v1i32, MVT::v1i64, MVT::v1f32,
+            MVT::v1f64, MVT::v2i8,  MVT::v2i16, MVT::v2i32, MVT::v2i64,
+            MVT::v2f32, MVT::v2f64, MVT::v4i8,  MVT::v4i16, MVT::v4i32,
+            MVT::v4f32, MVT::v8i8,  MVT::v8i16, MVT::v8i16, MVT::v16i8})
+        setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
+
     // NEON doesn't support masked loads/stores, but SME and SVE do.
     for (auto VT :
          {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, MVT::v1f64,
@@ -6615,6 +6633,132 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
   return DAG.getMergeValues({Ext, Chain}, DL);
 }
 
+SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Vec = Op.getOperand(0);
+  SDValue Mask = Op.getOperand(1);
+  SDValue Passthru = Op.getOperand(2);
+  EVT VecVT = Vec.getValueType();
+  EVT MaskVT = Mask.getValueType();
+  EVT ElmtVT = VecVT.getVectorElementType();
+  const bool IsFixedLength = VecVT.isFixedLengthVector();
+  const bool HasPassthru = !Passthru.isUndef();
+  unsigned MinElmts = VecVT.getVectorElementCount().getKnownMinValue();
+  EVT FixedVecVT = MVT::getVectorVT(ElmtVT.getSimpleVT(), MinElmts);
+
+  assert(VecVT.isVector() && "Input to VECTOR_COMPRESS must be vector.");
+
+  if (!Subtarget->hasSVE())
+    return SDValue();
+
+  if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
+    return SDValue();
+
+  // We can use the SVE register containing the NEON vector in its lowest bits.
+  if (IsFixedLength) {
+    EVT ScalableVecVT =
+        MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), MinElmts);
+    EVT ScalableMaskVT = MVT::getScalableVectorVT(
+        MaskVT.getVectorElementType().getSimpleVT(), MinElmts);
+
+    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+                      DAG.getUNDEF(ScalableVecVT), Vec,
+                      DAG.getConstant(0, DL, MVT::i64));
+    Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableMaskVT,
+                       DAG.getUNDEF(ScalableMaskVT), Mask,
+                       DAG.getConstant(0, DL, MVT::i64));
+    Mask = DAG.getNode(ISD::TRUNCATE, DL,
+                       ScalableMaskVT.changeVectorElementType(MVT::i1), Mask);
+    Passthru = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+                           DAG.getUNDEF(ScalableVecVT), Passthru,
+                           DAG.getConstant(0, DL, MVT::i64));
+
+    VecVT = Vec.getValueType();
+    MaskVT = Mask.getValueType();
+  }
+
+  // Special case where we can't use svcompact but can do a compressing store
+  // and then reload the vector.
+  if (VecVT == MVT::nxv8i8 || VecVT == MVT::nxv16i8 || VecVT == MVT::nxv8i16) {
+    SDValue StackPtr = DAG.CreateStackTemporary(VecVT);
+    int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+    MachinePointerInfo PtrInfo =
+        MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
+
+    MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
+        PtrInfo, MachineMemOperand::Flags::MOStore,
+        LocationSize::precise(VecVT.getStoreSize()),
+        DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+
+    SDValue Chain = DAG.getEntryNode();
+    if (HasPassthru)
+      Chain = DAG.getStore(Chain, DL, Passthru, StackPtr, PtrInfo);
+
+    Chain = DAG.getMaskedStore(Chain, DL, Vec, StackPtr, DAG.getUNDEF(MVT::i64),
+                               Mask, VecVT, MMO, ISD::UNINDEXED,
+                               /*IsTruncating=*/false, /*IsCompressing=*/true);
+
+    SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+
+    if (IsFixedLength)
+      Compressed = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FixedVecVT,
+                               Compressed, DAG.getConstant(0, DL, MVT::i64));
+
+    return Compressed;
+  }
+
+  // Only <vscale x {2|4} x {i32|i64}> supported for svcompact.
+  if (MinElmts != 2 && MinElmts != 4)
+    return SDValue();
+
+  // Get legal type for svcompact instruction
+  EVT ContainerVT = getSVEContainerType(VecVT);
+  EVT CastVT = VecVT.changeVectorElementTypeToInteger();
+
+  // Convert to i32 or i64 for smaller types, as these are the only supported
+  // sizes for svcompact.
+  if (ContainerVT != VecVT) {
+    Vec = DAG.getBitcast(CastVT, Vec);
+    Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
+  }
+
+  SDValue Compressed = DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
+      DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
+
+  // svcompact fills with 0s, so if our passthru is all 0s, do nothing here.
+  if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
+    SDValue Offset = DAG.getNode(
+        ISD::ZERO_EXTEND, DL, MaskVT.changeVectorElementType(MVT::i32), Mask);
+    Offset = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, Offset);
+    Compressed =
+        DAG.getNode(ISD::VP_MERGE, DL, VecVT,
+                    DAG.getSplatVector(MaskVT, DL,
+                                       DAG.getAllOnesConstant(
+                                           DL, MaskVT.getVectorElementType())),
+                    Compressed, Passthru, Offset);
+  }
+
+  // Extracting from a legal SVE type before truncating produces better code.
+  if (IsFixedLength) {
+    Compressed = DAG.getNode(
+        ISD::EXTRACT_SUBVECTOR, DL,
+        FixedVecVT.changeVectorElementType(ContainerVT.getVectorElementType()),
+        Compressed, DAG.getConstant(0, DL, MVT::i64));
+    CastVT = FixedVecVT.changeVectorElementTypeToInteger();
+    VecVT = FixedVecVT;
+  }
+
+  // If we changed the element type before, we need to convert it back.
+  if (ContainerVT != VecVT) {
+    Compressed = DAG.getNode(ISD::TRUNCATE, DL, CastVT, Compressed);
+    Compressed = DAG.getBitcast(VecVT, Compressed);
+  }
+
+  return Compressed;
+}
+
 // Generate SUBS and CSEL for integer abs.
 SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
   MVT VT = Op.getSimpleValueType();
@@ -6995,6 +7139,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerDYNAMIC_STACKALLOC(Op, DAG);
   case ISD::VSCALE:
     return LowerVSCALE(Op, DAG);
+  case ISD::VECTOR_COMPRESS:
+    return LowerVECTOR_COMPRESS(Op, DAG);
   case ISD::ANY_EXTEND:
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
@@ -22928,6 +23074,67 @@ static SDValue combineI8TruncStore(StoreSDNode *ST, SelectionDAG &DAG,
   return Chain;
 }
 
+static SDValue combineVECTOR_COMPRESSStore(SelectionDAG &DAG,
+                                           StoreSDNode *Store,
+                                           const AArch64Subtarget *Subtarget) {
+  // If the regular store is preceded by an VECTOR_COMPRESS, we can combine them
+  // into a compressing store for scalable vectors in SVE.
+  SDValue VecOp = Store->getValue();
+  EVT VecVT = VecOp.getValueType();
+  if (VecOp.getOpcode() != ISD::VECTOR_COMPRESS || !Subtarget->hasSVE())
+    return SDValue();
+
+  bool IsFixedLength = VecVT.isFixedLengthVector();
+  if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
+    return SDValue();
+
+  SDLoc DL(Store);
+  SDValue Vec = VecOp.getOperand(0);
+  SDValue Mask = VecOp.getOperand(1);
+  SDValue Passthru = VecOp.getOperand(2);
+  EVT MemVT = Store->getMemoryVT();
+  MachineMemOperand *MMO = Store->getMemOperand();
+  SDValue Chain = Store->getChain();
+
+  // We can use the SVE register containing the NEON vector in its lowest bits.
+  if (IsFixedLength) {
+    EVT ElmtVT = VecVT.getVectorElementType();
+    unsigned NumElmts = VecVT.getVectorNumElements();
+    EVT ScalableVecVT =
+        MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), NumElmts);
+    EVT ScalableMaskVT = MVT::getScalableVectorVT(
+        Mask.getValueType().getVectorElementType().getSimpleVT(), NumElmts);
+
+    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+                      DAG.getUNDEF(ScalableVecVT), Vec,
+                      DAG.getConstant(0, DL, MVT::i64));
+    Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableMaskVT,
+                       DAG.getUNDEF(ScalableMaskVT), Mask,
+                       DAG.getConstant(0, DL, MVT::i64));
+    Mask = DAG.getNode(ISD::TRUNCATE, DL,
+                       ScalableMaskVT.changeVectorElementType(MVT::i1), Mask);
+    Passthru = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+                           DAG.getUNDEF(ScalableVecVT), Passthru,
+                           DAG.getConstant(0, DL, MVT::i64));
+
+    MemVT = ScalableVecVT;
+    MMO->setType(LLT::scalable_vector(NumElmts, ElmtVT.getSizeInBits()));
+  }
+
+  // If the passthru is all 0s, we don't need an explicit passthru store.
+  unsigned MinElmts = VecVT.getVectorMinNumElements();
+  if (ISD::isConstantSplatVectorAllZeros(Passthru.getNode()) && (MinElmts == 2 || MinElmts == 4))
+    return SDValue();
+
+  if (!Passthru.isUndef())
+    Chain = DAG.getStore(Chain, DL, Passthru, Store->getBasePtr(), MMO);
+
+  return DAG.getMaskedStore(Chain, DL, Vec, Store->getBasePtr(),
+                            DAG.getUNDEF(MVT::i64), Mask, MemVT, MMO,
+                            ISD::UNINDEXED, Store->isTruncatingStore(),
+                            /*IsCompressing=*/true);
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -22972,6 +23179,9 @@ static SDValue performSTORECombine(SDNode *N,
   if (SDValue Store = combineBoolVectorAndTruncateStore(DAG, ST))
     return Store;
 
+  if (SDValue Store = combineVECTOR_COMPRESSStore(DAG, ST, Subtarget))
+    return Store;
+
   if (ST->isTruncatingStore()) {
     EVT StoreVT = ST->getMemoryVT();
     if (!isHalvingTruncateOfLegalScalableType(ValueVT, StoreVT))
@@ -26214,6 +26424,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
   case ISD::VECREDUCE_UMIN:
     Results.push_back(LowerVECREDUCE(SDValue(N, 0), DAG));
     return;
+  case ISD::VECTOR_COMPRESS:
+    if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
+      Results.push_back(Res);
+    return;
   case ISD::ADD:
   case ISD::FADD:
     ReplaceAddWithADDP(N, Results, DAG, Subtarget);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 81e15185f985d..517b1ba1fd400 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1073,6 +1073,8 @@ class AArch64TargetLowering : public TargetLowering {
 
   SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerVECTOR_COMPRESS(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINTRINSIC_VOID(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
new file mode 100644
index 0000000000000..cdebb0db47ceb
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
@@ -0,0 +1,496 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64 -mattr=+sve < %s | FileCheck %s
+
+define <vscale x 2 x i8> @test_compress_nxv2i8(<vscale x 2 x i8> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv2i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    ret
+    %out = call <vscale x 2 x i8> @llvm.experimental.vector.compress(<vscale x 2 x i8> %vec, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
+    ret <vscale x 2 x i8> %out
+}
+
+define <vscale x 2 x i16> @test_compress_nxv2i16(<vscale x 2 x i16> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv2i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    ret
+    %out = call <vscale x 2 x i16> @llvm.experimental.vector.compress(<vscale x 2 x i16> %vec, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
+    ret <vscale x 2 x i16> %out
+}
+
+define <vscale x 2 x i32> @test_compress_nxv2i32(<vscale x 2 x i32> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv2i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    ret
+    %out = call <vscale x 2 x i32> @llvm.experimental.vector.compress(<vscale x 2 x i32> %vec, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
+    ret <vscale x 2 x i32> %out
+}
+
+define <vscale x 2 x i64> @test_compress_nxv2i64(<vscale x 2 x i64> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    ret
+    %out = call <vscale x 2 x i64> @llvm.experimental.vector.compress(<vscale x 2 x i64> %vec, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef)
+    ret <vscale x 2 x i64> %out
+}
+
+define <vscale x 2 x float> @test_compress_nxv2f32(<vscale x 2 x float> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv2f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    ret
+    %out = call <vscale x 2 x float> @llvm.experimental.vector.compress(<vscale x 2 x float> %vec, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef)
+    ret <vscale x 2 x float> %out
+}
+
+define <vscale x 2 x double> @test_compress_nxv2f64(<vscale x 2 x double> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    ret
+    %out = call <vscale x 2 x double> @llvm.experimental.vector.compress(<vscale x 2 x double> %vec, <vscale x 2 x i1> %mask, <vscale x 2 x double> undef)
+    ret <vscale x 2 x double> %out
+}
+
+define <vscale x 4 x i8> @test_compress_nxv4i8(<vscale x 4 x i8> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv4i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    ret
+    %out = call <vscale x 4 x i8> @llvm.experimental.vector.compress(<vscale x 4 x i8> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef)
+    ret <vscale x 4 x i8> %out
+}
+
+define <vscale x 4 x i16> @test_compress_nxv4i16(<vscale x 4 x i16> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv4i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    ret
+    %out = call <vscale x 4 x i16> @llvm.experimental.vector.compress(<vscale x 4 x i16> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef)
+    ret <vscale x 4 x i16> %out
+}
+
+define <vscale x 4 x i32> @test_compress_nxv4i32(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    ret
+    %out = call <vscale x 4 x i32> @llvm.experimental.vector.compress(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+    ret <vscale x 4 x i32> %out
+}
+
+define <vscale x 4 x float> @test_compress_nxv4f32(<vscale x 4 x float> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv4f32:
+; CHE...
[truncated]

Copy link

github-actions bot commented Jul 29, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp Outdated Show resolved Hide resolved
ISD::ZERO_EXTEND, DL, MaskVT.changeVectorElementType(MVT::i32), Mask);
Offset = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, Offset);
Compressed =
DAG.getNode(ISD::VP_MERGE, DL, VecVT,
Copy link
Collaborator

@davemgreen davemgreen Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VP_MERGE are not really supported or encouraged by the AArch64 backend. Is there an alternative we can emit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what would be the AArch64-way to express this logic? I copied it from a RICSV example that came up in the original discussion.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a passthru vector is given, all remaining lanes are filled with the
corresponding lane's value from passthru.

What does 'corresponding lane' mean in this case? If the mask is <1, 0, 0, 1>, would the passthru for the zero'ed lanes expected to be <_, _, p, p> or <_, p, p, _> (where 'p' means the passthru value and '_' for don't care)

If the former, then I guess you could do a popcount of the predicate, create a mask from that, and then do a vector select?

This also makes me wonder, would it be better to define the intrinsic to make the other lanes undefined, rather than adding a passthru parameter to the intrinsic? That would make the operation easier to codegen, and we can use existing intrinsics to implement the passthru behaviour.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The passthru exists because it's useful for some combinations of target/passthru value. For SVE in particular, for a non-zero passthru, we need to explicitly construct a mask, but other targets support it directly. This was discussed in #92289.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sdesmalen-arm The "corresponding lanes" are the remainder, used to fill up empty slots in the output. So vec=<a, b, c, d>, mask=<1, 0, 0, 1>, passthru=<w, x, y, z> would result in <a, d, y, z>. In your example, it would be <_, _, p, p>.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davemgreen using VP_MERGE results in the following assembly (in one of the tests)

cntp	x8, p0, p0.s
index	z2.s, #0, #1
compact	z0.s, p0, z0.s
ptrue	p1.s
mov	z3.s, w8
cmphi	p1.s, p1/z, z3.s, z2.s
sel	z0.s, p1, z0.s, z1.s

So it is doing the right thing. I could manually add the instructions instead of using VP_MERGE, but I'm not sure that makes sense. What would you suggest to do here?

llvm/test/CodeGen/AArch64/sve-vector-compress.ll Outdated Show resolved Hide resolved
@davemgreen davemgreen requested a review from sdesmalen-arm July 29, 2024 17:29
@lawben lawben marked this pull request as draft July 30, 2024 15:23
@lawben
Copy link
Contributor Author

lawben commented Jul 30, 2024

So it seems I misunderstood the st1w instruction. I thought the "contiguous store" is a "compressing" store. But that does not seem the be the case. So all the paths that do a masked store with isCompressing=true are wrong. Sorry :(

I'll fix this and convert this PR back to a normal one once it is fixed. But it looks like we can only use compact for 32 and 64 bit elements. That will shrink the PR by quite a bit.

@davemgreen
Copy link
Collaborator

For i8/i16 would it be possibly to expand the vectors to multiple i32 vectors, perform the compact, then shrink the result down again? It might not be the prettiest/fastest code, but should hopefully allow them to be supported.

@lawben
Copy link
Contributor Author

lawben commented Jul 31, 2024

For i8/i16 would it be possibly to expand the vectors to multiple i32 vectors, perform the compact, then shrink the result down again? It might not be the prettiest/fastest code, but should hopefully allow them to be supported.

@davemgreen I'm working on this right now. For i8/i16 vectors with 2 or 4 elements, this is trivial, as we can just extend the "container" vector. For vectors with more elements, e.g., <vscale x 16 x i8> I'm trying to convert this to <vscale x 16 x i32> and then let DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS handle the illegal type again. But I'm stuck here a bit. After returning a VECTOR_COMPRESS with a 16 x i32 vector from the custom lowering, the next step tries to expand this, which fails for scalable vectors.

Is there a way to "ignore" this for now and let it get legalized first, which contains the logic for recursively splitting and merging large vectors with VECTOR_COMPRESS?

@lawben
Copy link
Contributor Author

lawben commented Aug 7, 2024

Converting back to a regular PR, as this now does the correct thing.

I tried to implement this in GlobalISel too, but there seem to be some issues with SVE there and non-legal types. So I cannot implement the same behavior in both SelectionDAG and GlobalISel. Does it make sense to skip GlobalISel for now or implement a part of the feature and then add stuff later? This can also be in a separate PR.

@lawben lawben marked this pull request as ready for review August 7, 2024 08:53
@davemgreen
Copy link
Collaborator

I tried to implement this in GlobalISel too, but there seem to be some issues with SVE there and non-legal types. So I cannot implement the same behavior in both SelectionDAG and GlobalISel. Does it make sense to skip GlobalISel for now or implement a part of the feature and then add stuff later? This can also be in a separate PR.

Yeah - There isn't a lot of GISel SVE support at the moment, just the very basics.

@davemgreen I'm working on this right now. For i8/i16 vectors with 2 or 4 elements, this is trivial, as we can just extend the "container" vector. For vectors with more elements, e.g., <vscale x 16 x i8> I'm trying to convert this to <vscale x 16 x i32> and then let DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS handle the illegal type again. But I'm stuck here a bit. After returning a VECTOR_COMPRESS with a 16 x i32 vector from the custom lowering, the next step tries to expand this, which fails for scalable vectors.

Is there a way to "ignore" this for now and let it get legalized first, which contains the logic for recursively splitting and merging large vectors with VECTOR_COMPRESS?

Sorry for not replying earlier - my understanding is that we are going from "no support" to "some support", so something working is better than nothing!

Copy link
Collaborator

@davemgreen davemgreen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I was not around yesterday. This LGTM from what I can tell, thanks for the updates.

bool HasCustomLowering = false;
EVT CheckVT = LoVT;
while (CheckVT.getVectorMinNumElements() > 1) {
if (TLI.isOperationCustom(ISD::VECTOR_COMPRESS, CheckVT)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isOperationLegalOrCustom, just in case it can be legal?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used that originally, but it did not work in all cases. The problem here is that isOperationLegalOrCustom() requires that the type is legal. Types like <4 x i8> are not legal on AArch64 (so the check fails) but we have a custom lowering for it.

I think you are right and we should also catch the legal case, so I'll change this to isOperationLegal() || isOperationCustom(). It's a bit unfortunate that this is not equivalent to isOperationLegalOrCustom().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Immediately after sending the reply I realized that this doesn't matter that much. The result type is only split until it reaches a legal type. So we never actually get into the case where we split this from <8 x i8> to <4 x i8>. So in any case we would need to duplicate the split + compress + merge logic in AArch64ISelLowering in a follow up PR.

I'll still use the isOperationLegal() || isOperationCustom() approach though, becasue it is more general.

@lawben lawben merged commit e0ad56b into llvm:main Aug 13, 2024
8 checks passed
bwendling pushed a commit to bwendling/llvm-project that referenced this pull request Aug 15, 2024
…#101015)

This is a follow-up to llvm#92289 that adds custom lowering of the new
`@llvm.experimental.vector.compress` intrinsic on AArch64 with SVE
instructions.

Some vectors have a `compact` instruction that they can be lowered to.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AArch64 llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants