-
Notifications
You must be signed in to change notification settings - Fork 12.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AArch64] Extend custom lowering for SVE types in @llvm.experimental.vector.compress
#105515
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-backend-aarch64 Author: Lawrence Benson (lawben) ChangesThis is a follow-up to #101015. We now support Full diff: https://github.com/llvm/llvm-project/pull/105515.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e1d265fdf0d1a8..f4d3fa114ddc3d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1781,16 +1781,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::v2f32, MVT::v4f32, MVT::v2f64})
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
- // We can lower types that have <vscale x {2|4}> elements to compact.
+ // We can lower all legal (or smaller) SVE types to `compact`.
for (auto VT :
{MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
- MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
+ MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32,
+ MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
// If we have SVE, we can use SVE logic for legal (or smaller than legal)
// NEON vectors in the lowest bits of the SVE register.
for (auto VT : {MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64, MVT::v2f32,
- MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32})
+ MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32,
+ MVT::v8i8, MVT::v8i16, MVT::v16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
// Histcnt is SVE2 only
@@ -6659,10 +6661,6 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
return SDValue();
- // Only <vscale x {4|2} x {i32|i64}> supported for compact.
- if (MinElmts != 2 && MinElmts != 4)
- return SDValue();
-
// We can use the SVE register containing the NEON vector in its lowest bits.
if (IsFixedLength) {
EVT ScalableVecVT =
@@ -6690,16 +6688,67 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
EVT ContainerVT = getSVEContainerType(VecVT);
EVT CastVT = VecVT.changeVectorElementTypeToInteger();
- // Convert to i32 or i64 for smaller types, as these are the only supported
- // sizes for compact.
- if (ContainerVT != VecVT) {
- Vec = DAG.getBitcast(CastVT, Vec);
- Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
- }
+ // These vector types aren't supported by the `compact` instruction, so
+ // we split and compact them as <vscale x 4 x i32>, store them on the stack,
+ // and then merge them again. In the other cases, emit compact directly.
+ SDValue Compressed;
+ if (VecVT == MVT::nxv8i16 || VecVT == MVT::nxv8i8 || VecVT == MVT::nxv16i8) {
+ SDValue Chain = DAG.getEntryNode();
+ SDValue StackPtr = DAG.CreateStackTemporary(
+ VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+ MachineFunction &MF = DAG.getMachineFunction();
+
+ EVT PartialVecVT =
+ EVT::getVectorVT(*DAG.getContext(), ElmtVT, 4, /*isScalable*/ true);
+ EVT OffsetVT = getVectorIdxTy(DAG.getDataLayout());
+ SDValue Offset = DAG.getConstant(0, DL, OffsetVT);
+
+ for (unsigned I = 0; I < MinElmts; I += 4) {
+ SDValue PartialVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, PartialVecVT,
+ Vec, DAG.getVectorIdxConstant(I, DL));
+ PartialVec = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv4i32, PartialVec);
+
+ SDValue PartialMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv4i1,
+ Mask, DAG.getVectorIdxConstant(I, DL));
+
+ SDValue PartialCompressed = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32,
+ DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64),
+ PartialMask, PartialVec);
+ PartialCompressed =
+ DAG.getNode(ISD::TRUNCATE, DL, PartialVecVT, PartialCompressed);
+
+ SDValue OutPtr = DAG.getNode(
+ ISD::ADD, DL, StackPtr.getValueType(), StackPtr,
+ DAG.getNode(
+ ISD::MUL, DL, OffsetVT, Offset,
+ DAG.getConstant(ElmtVT.getScalarSizeInBits() / 8, DL, OffsetVT)));
+ Chain = DAG.getStore(Chain, DL, PartialCompressed, OutPtr,
+ MachinePointerInfo::getUnknownStack(MF));
+
+ SDValue PartialOffset = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, OffsetVT,
+ DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64),
+ PartialMask, PartialMask);
+ Offset = DAG.getNode(ISD::ADD, DL, OffsetVT, Offset, PartialOffset);
+ }
+
+ MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
+ MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());
+ Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+ } else {
+ // Convert to i32 or i64 for smaller types, as these are the only supported
+ // sizes for compact.
+ if (ContainerVT != VecVT) {
+ Vec = DAG.getBitcast(CastVT, Vec);
+ Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
+ }
- SDValue Compressed = DAG.getNode(
- ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
- DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
+ Compressed = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
+ DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask,
+ Vec);
+ }
// compact fills with 0s, so if our passthru is all 0s, do nothing here.
if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
index 84c15e4fbc33c7..fc8cbea0d47156 100644
--- a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
@@ -91,6 +91,101 @@ define <vscale x 4 x float> @test_compress_nxv4f32(<vscale x 4 x float> %vec, <v
ret <vscale x 4 x float> %out
}
+define <vscale x 8 x i8> @test_compress_nxv8i8(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv8i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: uunpklo z1.s, z0.h
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: addpl x9, sp, #4
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: cntp x8, p1, p1.s
+; CHECK-NEXT: compact z1.s, p1, z1.s
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: ptrue p1.h
+; CHECK-NEXT: st1b { z1.s }, p0, [sp, #2, mul vl]
+; CHECK-NEXT: st1b { z0.s }, p0, [x9, x8]
+; CHECK-NEXT: ld1b { z0.h }, p1/z, [sp, #1, mul vl]
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %out = call <vscale x 8 x i8> @llvm.experimental.vector.compress(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+ ret <vscale x 8 x i8> %out
+}
+
+define <vscale x 8 x i16> @test_compress_nxv8i16(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv8i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: uunpklo z1.s, z0.h
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: cntp x8, p1, p1.s
+; CHECK-NEXT: compact z1.s, p1, z1.s
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: ptrue p1.h
+; CHECK-NEXT: st1h { z1.s }, p0, [sp]
+; CHECK-NEXT: st1h { z0.s }, p0, [x9, x8, lsl #1]
+; CHECK-NEXT: ld1h { z0.h }, p1/z, [sp]
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %out = call <vscale x 8 x i16> @llvm.experimental.vector.compress(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
+ ret <vscale x 8 x i16> %out
+}
+
+define <vscale x 16 x i8> @test_compress_nxv16i8(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv16i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: uunpklo z1.h, z0.b
+; CHECK-NEXT: punpklo p2.h, p0.b
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: uunpkhi z0.h, z0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: punpklo p3.h, p2.b
+; CHECK-NEXT: punpkhi p2.h, p2.b
+; CHECK-NEXT: uunpklo z2.s, z1.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: cntp x8, p3, p3.s
+; CHECK-NEXT: uunpklo z3.s, z0.h
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: compact z2.s, p3, z2.s
+; CHECK-NEXT: compact z1.s, p2, z1.s
+; CHECK-NEXT: punpklo p3.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: ptrue p0.b
+; CHECK-NEXT: st1b { z2.s }, p1, [sp]
+; CHECK-NEXT: st1b { z1.s }, p1, [x9, x8]
+; CHECK-NEXT: compact z1.s, p3, z3.s
+; CHECK-NEXT: incp x8, p2.s
+; CHECK-NEXT: st1b { z1.s }, p1, [x9, x8]
+; CHECK-NEXT: incp x8, p3.s
+; CHECK-NEXT: st1b { z0.s }, p1, [x9, x8]
+; CHECK-NEXT: ld1b { z0.b }, p0/z, [sp]
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %out = call <vscale x 16 x i8> @llvm.experimental.vector.compress(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+ ret <vscale x 16 x i8> %out
+}
+
define <vscale x 4 x i4> @test_compress_illegal_element_type(<vscale x 4 x i4> %vec, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: test_compress_illegal_element_type:
; CHECK: // %bb.0:
@@ -240,6 +335,40 @@ define <2 x i16> @test_compress_v2i16_with_sve(<2 x i16> %vec, <2 x i1> %mask) {
ret <2 x i16> %out
}
+define <8 x i16> @test_compress_v8i16_with_sve(<8 x i16> %vec, <8 x i1> %mask) {
+; CHECK-LABEL: test_compress_v8i16_with_sve:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and z1.h, z1.h, #0x1
+; CHECK-NEXT: cmpne p1.h, p0/z, z1.h, #0
+; CHECK-NEXT: uunpklo z1.s, z0.h
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: punpklo p2.h, p1.b
+; CHECK-NEXT: punpkhi p1.h, p1.b
+; CHECK-NEXT: compact z1.s, p2, z1.s
+; CHECK-NEXT: cntp x8, p2, p2.s
+; CHECK-NEXT: compact z0.s, p1, z0.s
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: st1h { z1.s }, p1, [sp]
+; CHECK-NEXT: st1h { z0.s }, p1, [x9, x8, lsl #1]
+; CHECK-NEXT: ld1h { z0.h }, p0/z, [sp]
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %out = call <8 x i16> @llvm.experimental.vector.compress(<8 x i16> %vec, <8 x i1> %mask, <8 x i16> undef)
+ ret <8 x i16> %out
+}
+
define <vscale x 4 x i32> @test_compress_nxv4i32_with_passthru(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) {
; CHECK-LABEL: test_compress_nxv4i32_with_passthru:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add tests for f16/bf16.
I've found a bug during manual testing and I'm not sure yet what the issue is. So I'd put this PR is on hold for now, until I've fixed this. As I'm traveling for ~10 days, this will probably take two weeks or so. It looks like when I extract subvectors to use |
This is a follow-up to #101015. We now support
@llvm.experimental.vector.compress
for SVE types that don't map directly tocompact
, i.e.,<vscale x 8 x ..>
and<vscale x 16 x ..>
. We can also use this logic for corresponding NEON vectors.