[AArch64] Extend custom lowering for SVE types in @llvm.experimental.vector.compress #105515

Open
wants to merge 6 commits into main
Changes from 3 commits
81 changes: 65 additions & 16 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1781,16 +1781,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::v2f32, MVT::v4f32, MVT::v2f64})
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);

// We can lower types that have <vscale x {2|4}> elements to compact.
// We can lower all legal (or smaller) SVE types to `compact`.
for (auto VT :
{MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32,
MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

// If we have SVE, we can use SVE logic for legal (or smaller than legal)
// NEON vectors in the lowest bits of the SVE register.
for (auto VT : {MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64, MVT::v2f32,
MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32})
MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32,
MVT::v8i8, MVT::v8i16, MVT::v16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

// Histcnt is SVE2 only
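As the comments in this hunk note, fixed-length NEON vectors that are legal (or smaller than legal) can reuse the SVE `compact` path by viewing them as the low lanes of a scalable register. Below is a minimal C++ sketch of that idea using ACLE SVE intrinsics; the function name and the integer mask encoding are illustrative assumptions, and this is not the code the backend emits.

```cpp
#include <arm_sve.h>
#include <cstdint>

// Illustration only: compress a fixed 4 x i32 vector by treating it as the
// low lanes of an SVE register, so the COMPACT instruction can be used.
void compress_v4i32(const uint32_t vec[4], const uint32_t mask[4],
                    uint32_t out[4]) {
  svbool_t low4 = svwhilelt_b32(0u, 4u);        // predicate covering lanes 0..3
  svuint32_t v = svld1_u32(low4, vec);          // NEON-sized data in low SVE lanes
  svuint32_t m = svld1_u32(low4, mask);
  svbool_t active = svcmpne_n_u32(low4, m, 0);  // non-zero mask lanes are active
  svst1_u32(low4, out, svcompact_u32(active, v));
}
```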
@@ -6659,10 +6661,6 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
return SDValue();

// Only <vscale x {4|2} x {i32|i64}> supported for compact.
if (MinElmts != 2 && MinElmts != 4)
return SDValue();

// We can use the SVE register containing the NEON vector in its lowest bits.
if (IsFixedLength) {
EVT ScalableVecVT =
@@ -6690,16 +6688,67 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
EVT ContainerVT = getSVEContainerType(VecVT);
EVT CastVT = VecVT.changeVectorElementTypeToInteger();

// Convert to i32 or i64 for smaller types, as these are the only supported
// sizes for compact.
if (ContainerVT != VecVT) {
Vec = DAG.getBitcast(CastVT, Vec);
Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
}
// These vector types aren't supported by the `compact` instruction, so
// we split and compact them as <vscale x 4 x i32>, store them on the stack,
// and then merge them again. In the other cases, emit compact directly.
SDValue Compressed;
if (VecVT == MVT::nxv8i16 || VecVT == MVT::nxv8i8 || VecVT == MVT::nxv16i8) {
SDValue Chain = DAG.getEntryNode();
SDValue StackPtr = DAG.CreateStackTemporary(
VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
MachineFunction &MF = DAG.getMachineFunction();

EVT PartialVecVT =
EVT::getVectorVT(*DAG.getContext(), ElmtVT, 4, /*isScalable*/ true);
EVT OffsetVT = getVectorIdxTy(DAG.getDataLayout());
SDValue Offset = DAG.getConstant(0, DL, OffsetVT);

for (unsigned I = 0; I < MinElmts; I += 4) {
SDValue PartialVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, PartialVecVT,
Vec, DAG.getVectorIdxConstant(I, DL));
PartialVec = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv4i32, PartialVec);

SDValue PartialMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv4i1,
Mask, DAG.getVectorIdxConstant(I, DL));

SDValue PartialCompressed = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32,
DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64),
PartialMask, PartialVec);
PartialCompressed =
DAG.getNode(ISD::TRUNCATE, DL, PartialVecVT, PartialCompressed);

SDValue OutPtr = DAG.getNode(
ISD::ADD, DL, StackPtr.getValueType(), StackPtr,
DAG.getNode(
ISD::MUL, DL, OffsetVT, Offset,
DAG.getConstant(ElmtVT.getScalarSizeInBits() / 8, DL, OffsetVT)));
Chain = DAG.getStore(Chain, DL, PartialCompressed, OutPtr,
MachinePointerInfo::getUnknownStack(MF));

SDValue PartialOffset = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, OffsetVT,
DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64),
PartialMask, PartialMask);
Offset = DAG.getNode(ISD::ADD, DL, OffsetVT, Offset, PartialOffset);
}

MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());
Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
} else {
// Convert to i32 or i64 for smaller types, as these are the only supported
// sizes for compact.
if (ContainerVT != VecVT) {
Vec = DAG.getBitcast(CastVT, Vec);
Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
}

SDValue Compressed = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
Compressed = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask,
Vec);
}
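For readers following the new path above, here is a rough C++ sketch of the split-compact-store-merge strategy for an `nxv8i16` input, written with ACLE SVE intrinsics rather than SelectionDAG nodes. It is an illustration under stated assumptions (the function name is made up, the stack temporary is modelled as a local buffer sized for the 2048-bit architectural maximum vector length, and passthru is ignored, which is equivalent to an all-zero passthru), not the code the backend generates.

```cpp
#include <arm_sve.h>
#include <cstdint>

// Illustration of the strategy above for <vscale x 8 x i16>: widen each half
// to 32-bit lanes, compact the halves separately, store them back-to-back,
// and reload the merged result as one 16-bit vector.
svuint16_t compress_nxv8i16(svbool_t mask, svuint16_t vec) {
  uint16_t buf[2048 / 16] = {0};              // covers the 2048-bit max VL

  svbool_t lo_p = svunpklo_b(mask);           // predicate for the low half
  svbool_t hi_p = svunpkhi_b(mask);           // predicate for the high half
  svuint32_t lo = svunpklo_u32(vec);          // widen halves to 32-bit lanes
  svuint32_t hi = svunpkhi_u32(vec);

  svbool_t p32 = svptrue_b32();
  svuint32_t lo_c = svcompact_u32(lo_p, lo);  // compact each half on its own
  svuint32_t hi_c = svcompact_u32(hi_p, hi);

  // The second half starts where the first half's active elements end
  // (this is what the running cntp-based offset computes in the code above).
  uint64_t lo_count = svcntp_b32(p32, lo_p);
  svst1h_u32(p32, buf, lo_c);                 // truncating store back to i16
  svst1h_u32(p32, buf + lo_count, hi_c);

  return svld1_u16(svptrue_b16(), buf);       // reload the merged vector
}
```

The `nxv16i8` case follows the same pattern with four `nxv4i32` parts, which is why the generated code in the tests below contains four compact/store pairs.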

// compact fills with 0s, so if our passthru is all 0s, do nothing here.
if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
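On the passthru handling mentioned at the end of this hunk (its body is collapsed here): `@llvm.experimental.vector.compress` packs the selected elements to the front and takes the remaining lanes from the passthru operand, so an all-zero passthru is already covered by the zero-fill of `compact`. A scalar reference loop, for illustration only (names are assumptions):

```cpp
#include <cstddef>
#include <cstdint>

// Scalar reference for compress-with-passthru semantics: selected elements
// are packed to the front, remaining lanes come from the passthru vector.
void compress_ref(const uint16_t *vec, const bool *mask,
                  const uint16_t *passthru, uint16_t *out, size_t n) {
  size_t idx = 0;
  for (size_t i = 0; i < n; ++i)
    if (mask[i])
      out[idx++] = vec[i];
  for (; idx < n; ++idx)
    out[idx] = passthru[idx];
}
```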
129 changes: 129 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-vector-compress.ll
@@ -91,6 +91,101 @@ define <vscale x 4 x float> @test_compress_nxv4f32(<vscale x 4 x float> %vec, <v
ret <vscale x 4 x float> %out
}

define <vscale x 8 x i8> @test_compress_nxv8i8(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask) {
; CHECK-LABEL: test_compress_nxv8i8:
; CHECK: // %bb.0:
; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: uunpklo z1.s, z0.h
; CHECK-NEXT: uunpkhi z0.s, z0.h
; CHECK-NEXT: addpl x9, sp, #4
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: cntp x8, p1, p1.s
; CHECK-NEXT: compact z1.s, p1, z1.s
; CHECK-NEXT: compact z0.s, p0, z0.s
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: ptrue p1.h
; CHECK-NEXT: st1b { z1.s }, p0, [sp, #2, mul vl]
; CHECK-NEXT: st1b { z0.s }, p0, [x9, x8]
; CHECK-NEXT: ld1b { z0.h }, p1/z, [sp, #1, mul vl]
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
%out = call <vscale x 8 x i8> @llvm.experimental.vector.compress(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
ret <vscale x 8 x i8> %out
}

define <vscale x 8 x i16> @test_compress_nxv8i16(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask) {
; CHECK-LABEL: test_compress_nxv8i16:
; CHECK: // %bb.0:
; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: uunpklo z1.s, z0.h
; CHECK-NEXT: uunpkhi z0.s, z0.h
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: cntp x8, p1, p1.s
; CHECK-NEXT: compact z1.s, p1, z1.s
; CHECK-NEXT: compact z0.s, p0, z0.s
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: ptrue p1.h
; CHECK-NEXT: st1h { z1.s }, p0, [sp]
; CHECK-NEXT: st1h { z0.s }, p0, [x9, x8, lsl #1]
; CHECK-NEXT: ld1h { z0.h }, p1/z, [sp]
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
%out = call <vscale x 8 x i16> @llvm.experimental.vector.compress(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
ret <vscale x 8 x i16> %out
}

define <vscale x 16 x i8> @test_compress_nxv16i8(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask) {
; CHECK-LABEL: test_compress_nxv16i8:
; CHECK: // %bb.0:
; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: uunpklo z1.h, z0.b
; CHECK-NEXT: punpklo p2.h, p0.b
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: uunpkhi z0.h, z0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: punpklo p3.h, p2.b
; CHECK-NEXT: punpkhi p2.h, p2.b
; CHECK-NEXT: uunpklo z2.s, z1.h
; CHECK-NEXT: uunpkhi z1.s, z1.h
; CHECK-NEXT: cntp x8, p3, p3.s
; CHECK-NEXT: uunpklo z3.s, z0.h
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: uunpkhi z0.s, z0.h
; CHECK-NEXT: compact z2.s, p3, z2.s
; CHECK-NEXT: compact z1.s, p2, z1.s
; CHECK-NEXT: punpklo p3.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: compact z0.s, p0, z0.s
; CHECK-NEXT: ptrue p0.b
; CHECK-NEXT: st1b { z2.s }, p1, [sp]
; CHECK-NEXT: st1b { z1.s }, p1, [x9, x8]
; CHECK-NEXT: compact z1.s, p3, z3.s
; CHECK-NEXT: incp x8, p2.s
; CHECK-NEXT: st1b { z1.s }, p1, [x9, x8]
; CHECK-NEXT: incp x8, p3.s
; CHECK-NEXT: st1b { z0.s }, p1, [x9, x8]
; CHECK-NEXT: ld1b { z0.b }, p0/z, [sp]
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
%out = call <vscale x 16 x i8> @llvm.experimental.vector.compress(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
ret <vscale x 16 x i8> %out
}

define <vscale x 4 x i4> @test_compress_illegal_element_type(<vscale x 4 x i4> %vec, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: test_compress_illegal_element_type:
; CHECK: // %bb.0:
@@ -240,6 +335,40 @@ define <2 x i16> @test_compress_v2i16_with_sve(<2 x i16> %vec, <2 x i1> %mask) {
ret <2 x i16> %out
}

define <8 x i16> @test_compress_v8i16_with_sve(<8 x i16> %vec, <8 x i1> %mask) {
; CHECK-LABEL: test_compress_v8i16_with_sve:
; CHECK: // %bb.0:
; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: shl v1.8h, v1.8h, #15
; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
; CHECK-NEXT: and z1.h, z1.h, #0x1
; CHECK-NEXT: cmpne p1.h, p0/z, z1.h, #0
; CHECK-NEXT: uunpklo z1.s, z0.h
; CHECK-NEXT: uunpkhi z0.s, z0.h
; CHECK-NEXT: punpklo p2.h, p1.b
; CHECK-NEXT: punpkhi p1.h, p1.b
; CHECK-NEXT: compact z1.s, p2, z1.s
; CHECK-NEXT: cntp x8, p2, p2.s
; CHECK-NEXT: compact z0.s, p1, z0.s
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: st1h { z1.s }, p1, [sp]
; CHECK-NEXT: st1h { z0.s }, p1, [x9, x8, lsl #1]
; CHECK-NEXT: ld1h { z0.h }, p0/z, [sp]
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
%out = call <8 x i16> @llvm.experimental.vector.compress(<8 x i16> %vec, <8 x i1> %mask, <8 x i16> undef)
ret <8 x i16> %out
}


define <vscale x 4 x i32> @test_compress_nxv4i32_with_passthru(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) {
; CHECK-LABEL: test_compress_nxv4i32_with_passthru: