Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AArch64] Extend custom lowering for SVE types in @llvm.experimental.vector.compress #105515

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

lawben
Copy link
Contributor

@lawben lawben commented Aug 21, 2024

This is a follow-up to #101015. We now support @llvm.experimental.vector.compress for SVE types that don't map directly to compact, i.e., <vscale x 8 x ..> and <vscale x 16 x ..>. We can also use this logic for corresponding NEON vectors.

@llvmbot
Copy link
Member

llvmbot commented Aug 21, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Lawrence Benson (lawben)

Changes

This is a follow-up to #101015. We now support @<!-- -->llvm.experimental.vector.compress for SVE types that don't map directly to compact, i.e., &lt;vscale x 8 x ..&gt; and &lt;vscale x 16 x ..&gt;. We can also use this logic for corresponding NEON vectors.


Full diff: https://github.com/llvm/llvm-project/pull/105515.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+65-16)
  • (modified) llvm/test/CodeGen/AArch64/sve-vector-compress.ll (+129)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e1d265fdf0d1a8..f4d3fa114ddc3d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1781,16 +1781,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                     MVT::v2f32, MVT::v4f32, MVT::v2f64})
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
 
-    // We can lower types that have <vscale x {2|4}> elements to compact.
+    // We can lower all legal (or smaller) SVE types to `compact`.
     for (auto VT :
          {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
-          MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
+          MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32,
+          MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8})
       setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
 
     // If we have SVE, we can use SVE logic for legal (or smaller than legal)
     // NEON vectors in the lowest bits of the SVE register.
     for (auto VT : {MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64, MVT::v2f32,
-                    MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32})
+                    MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32,
+                    MVT::v8i8, MVT::v8i16, MVT::v16i8})
       setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
 
     // Histcnt is SVE2 only
@@ -6659,10 +6661,6 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
   if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
     return SDValue();
 
-  // Only <vscale x {4|2} x {i32|i64}> supported for compact.
-  if (MinElmts != 2 && MinElmts != 4)
-    return SDValue();
-
   // We can use the SVE register containing the NEON vector in its lowest bits.
   if (IsFixedLength) {
     EVT ScalableVecVT =
@@ -6690,16 +6688,67 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
   EVT ContainerVT = getSVEContainerType(VecVT);
   EVT CastVT = VecVT.changeVectorElementTypeToInteger();
 
-  // Convert to i32 or i64 for smaller types, as these are the only supported
-  // sizes for compact.
-  if (ContainerVT != VecVT) {
-    Vec = DAG.getBitcast(CastVT, Vec);
-    Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
-  }
+  // These vector types aren't supported by the `compact` instruction, so
+  // we split and compact them as <vscale x 4 x i32>, store them on the stack,
+  // and then merge them again. In the other cases, emit compact directly.
+  SDValue Compressed;
+  if (VecVT == MVT::nxv8i16 || VecVT == MVT::nxv8i8 || VecVT == MVT::nxv16i8) {
+    SDValue Chain = DAG.getEntryNode();
+    SDValue StackPtr = DAG.CreateStackTemporary(
+        VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+    MachineFunction &MF = DAG.getMachineFunction();
+
+    EVT PartialVecVT =
+        EVT::getVectorVT(*DAG.getContext(), ElmtVT, 4, /*isScalable*/ true);
+    EVT OffsetVT = getVectorIdxTy(DAG.getDataLayout());
+    SDValue Offset = DAG.getConstant(0, DL, OffsetVT);
+
+    for (unsigned I = 0; I < MinElmts; I += 4) {
+      SDValue PartialVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, PartialVecVT,
+                                       Vec, DAG.getVectorIdxConstant(I, DL));
+      PartialVec = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv4i32, PartialVec);
+
+      SDValue PartialMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv4i1,
+                                        Mask, DAG.getVectorIdxConstant(I, DL));
+
+      SDValue PartialCompressed = DAG.getNode(
+          ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32,
+          DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64),
+          PartialMask, PartialVec);
+      PartialCompressed =
+          DAG.getNode(ISD::TRUNCATE, DL, PartialVecVT, PartialCompressed);
+
+      SDValue OutPtr = DAG.getNode(
+          ISD::ADD, DL, StackPtr.getValueType(), StackPtr,
+          DAG.getNode(
+              ISD::MUL, DL, OffsetVT, Offset,
+              DAG.getConstant(ElmtVT.getScalarSizeInBits() / 8, DL, OffsetVT)));
+      Chain = DAG.getStore(Chain, DL, PartialCompressed, OutPtr,
+                           MachinePointerInfo::getUnknownStack(MF));
+
+      SDValue PartialOffset = DAG.getNode(
+          ISD::INTRINSIC_WO_CHAIN, DL, OffsetVT,
+          DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64),
+          PartialMask, PartialMask);
+      Offset = DAG.getNode(ISD::ADD, DL, OffsetVT, Offset, PartialOffset);
+    }
+
+    MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
+        MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());
+    Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+  } else {
+    // Convert to i32 or i64 for smaller types, as these are the only supported
+    // sizes for compact.
+    if (ContainerVT != VecVT) {
+      Vec = DAG.getBitcast(CastVT, Vec);
+      Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
+    }
 
-  SDValue Compressed = DAG.getNode(
-      ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
-      DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
+    Compressed = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
+        DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask,
+        Vec);
+  }
 
   // compact fills with 0s, so if our passthru is all 0s, do nothing here.
   if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
index 84c15e4fbc33c7..fc8cbea0d47156 100644
--- a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
@@ -91,6 +91,101 @@ define <vscale x 4 x float> @test_compress_nxv4f32(<vscale x 4 x float> %vec, <v
     ret <vscale x 4 x float> %out
 }
 
+define <vscale x 8 x i8> @test_compress_nxv8i8(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv8i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    uunpklo z1.s, z0.h
+; CHECK-NEXT:    uunpkhi z0.s, z0.h
+; CHECK-NEXT:    addpl x9, sp, #4
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    cntp x8, p1, p1.s
+; CHECK-NEXT:    compact z1.s, p1, z1.s
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    st1b { z1.s }, p0, [sp, #2, mul vl]
+; CHECK-NEXT:    st1b { z0.s }, p0, [x9, x8]
+; CHECK-NEXT:    ld1b { z0.h }, p1/z, [sp, #1, mul vl]
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+    %out = call <vscale x 8 x i8> @llvm.experimental.vector.compress(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+    ret <vscale x 8 x i8> %out
+}
+
+define <vscale x 8 x i16> @test_compress_nxv8i16(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    uunpklo z1.s, z0.h
+; CHECK-NEXT:    uunpkhi z0.s, z0.h
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    cntp x8, p1, p1.s
+; CHECK-NEXT:    compact z1.s, p1, z1.s
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    st1h { z1.s }, p0, [sp]
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, x8, lsl #1]
+; CHECK-NEXT:    ld1h { z0.h }, p1/z, [sp]
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+    %out = call <vscale x 8 x i16> @llvm.experimental.vector.compress(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
+    ret <vscale x 8 x i16> %out
+}
+
+define <vscale x 16 x i8> @test_compress_nxv16i8(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    uunpklo z1.h, z0.b
+; CHECK-NEXT:    punpklo p2.h, p0.b
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    uunpkhi z0.h, z0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    punpklo p3.h, p2.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    uunpklo z2.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    cntp x8, p3, p3.s
+; CHECK-NEXT:    uunpklo z3.s, z0.h
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    uunpkhi z0.s, z0.h
+; CHECK-NEXT:    compact z2.s, p3, z2.s
+; CHECK-NEXT:    compact z1.s, p2, z1.s
+; CHECK-NEXT:    punpklo p3.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    st1b { z2.s }, p1, [sp]
+; CHECK-NEXT:    st1b { z1.s }, p1, [x9, x8]
+; CHECK-NEXT:    compact z1.s, p3, z3.s
+; CHECK-NEXT:    incp x8, p2.s
+; CHECK-NEXT:    st1b { z1.s }, p1, [x9, x8]
+; CHECK-NEXT:    incp x8, p3.s
+; CHECK-NEXT:    st1b { z0.s }, p1, [x9, x8]
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [sp]
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+    %out = call <vscale x 16 x i8> @llvm.experimental.vector.compress(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+    ret <vscale x 16 x i8> %out
+}
+
 define <vscale x 4 x i4> @test_compress_illegal_element_type(<vscale x 4 x i4> %vec, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: test_compress_illegal_element_type:
 ; CHECK:       // %bb.0:
@@ -240,6 +335,40 @@ define <2 x i16> @test_compress_v2i16_with_sve(<2 x i16> %vec, <2 x i1> %mask) {
     ret <2 x i16> %out
 }
 
+define <8 x i16> @test_compress_v8i16_with_sve(<8 x i16> %vec, <8 x i1> %mask) {
+; CHECK-LABEL: test_compress_v8i16_with_sve:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    shl v1.8h, v1.8h, #15
+; CHECK-NEXT:    cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT:    and z1.h, z1.h, #0x1
+; CHECK-NEXT:    cmpne p1.h, p0/z, z1.h, #0
+; CHECK-NEXT:    uunpklo z1.s, z0.h
+; CHECK-NEXT:    uunpkhi z0.s, z0.h
+; CHECK-NEXT:    punpklo p2.h, p1.b
+; CHECK-NEXT:    punpkhi p1.h, p1.b
+; CHECK-NEXT:    compact z1.s, p2, z1.s
+; CHECK-NEXT:    cntp x8, p2, p2.s
+; CHECK-NEXT:    compact z0.s, p1, z0.s
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    st1h { z1.s }, p1, [sp]
+; CHECK-NEXT:    st1h { z0.s }, p1, [x9, x8, lsl #1]
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [sp]
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+    %out = call <8 x i16> @llvm.experimental.vector.compress(<8 x i16> %vec, <8 x i1> %mask, <8 x i16> undef)
+    ret <8 x i16> %out
+}
+
 
 define <vscale x 4 x i32> @test_compress_nxv4i32_with_passthru(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) {
 ; CHECK-LABEL: test_compress_nxv4i32_with_passthru:

Copy link
Collaborator

@efriedma-quic efriedma-quic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add tests for f16/bf16.

@lawben
Copy link
Contributor Author

lawben commented Sep 11, 2024

I've found a bug during manual testing and I'm not sure yet what the issue is. So I'd put this PR is on hold for now, until I've fixed this. As I'm traveling for ~10 days, this will probably take two weeks or so.

It looks like when I extract subvectors to use compact for larger types and store them back to memory, the vectors are overwriting each other. Not sure if this has something to do with store reordering or something similar.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants