Skip to content

Commit

Permalink
[RISCV] Support select/merge like ops for bf16 vectors when have Zvfb…
Browse files Browse the repository at this point in the history
…fmin (#91936)
  • Loading branch information
jacquesguan authored Jun 6, 2024
1 parent 4b70294 commit d5ab38f
Show file tree
Hide file tree
Showing 11 changed files with 1,150 additions and 38 deletions.
32 changes: 28 additions & 4 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::EXTRACT_SUBVECTOR},
VT, Custom);
setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
if (Subtarget.hasStdExtZfbfmin()) {
if (Subtarget.hasVInstructionsF16())
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
else if (Subtarget.hasVInstructionsF16Minimal())
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
}
setOperationAction({ISD::VP_MERGE, ISD::VP_SELECT, ISD::SELECT}, VT,
Custom);
setOperationAction(ISD::SELECT_CC, VT, Expand);
// TODO: Promote to fp32.
}
}
Expand Down Expand Up @@ -1331,6 +1340,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::EXTRACT_SUBVECTOR},
VT, Custom);
setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
if (Subtarget.hasStdExtZfbfmin()) {
if (Subtarget.hasVInstructionsF16())
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
else if (Subtarget.hasVInstructionsF16Minimal())
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
}
setOperationAction(
{ISD::VP_MERGE, ISD::VP_SELECT, ISD::VSELECT, ISD::SELECT}, VT,
Custom);
// TODO: Promote to fp32.
continue;
}
Expand Down Expand Up @@ -6704,10 +6722,16 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::BUILD_VECTOR:
return lowerBUILD_VECTOR(Op, DAG, Subtarget);
case ISD::SPLAT_VECTOR:
if (Op.getValueType().getScalarType() == MVT::f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16())) {
if (Op.getValueType() == MVT::nxv32f16)
if ((Op.getValueType().getScalarType() == MVT::f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
Subtarget.hasStdExtZfhminOrZhinxmin() &&
!Subtarget.hasVInstructionsF16())) ||
(Op.getValueType().getScalarType() == MVT::bf16 &&
(Subtarget.hasVInstructionsBF16() && Subtarget.hasStdExtZfbfmin() &&
Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))) {
if (Op.getValueType() == MVT::nxv32f16 ||
Op.getValueType() == MVT::nxv32bf16)
return SplitVectorOp(Op, DAG);
SDLoc DL(Op);
SDValue NewScalar =
Expand Down
15 changes: 14 additions & 1 deletion llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,20 @@ class GetIntVTypeInfo<VTypeInfo vti> {
// Equivalent integer vector type. Eg.
// VI8M1 → VI8M1 (identity)
// VF64M4 → VI64M4
VTypeInfo Vti = !cast<VTypeInfo>(!subst("VF", "VI", !cast<string>(vti)));
VTypeInfo Vti = !cast<VTypeInfo>(!subst("VBF", "VI",
!subst("VF", "VI",
!cast<string>(vti))));
}

// This functor is used to obtain the fp vector type that has the same SEW and
// multiplier as the input parameter type.
class GetFpVTypeInfo<VTypeInfo vti> {
// Equivalent floating-point vector type. Eg.
// VF16M1 → VF16M1 (identity)
// VBF16M1 → VF16M1
// The mapping is done textually on the record name: "VI" is rewritten to
// "VF" first, then "VBF" is rewritten to "VF", and the resulting name is cast
// back to a VTypeInfo record.
// NOTE(review): presumably callers only pass types for which the rewritten
// record name exists — confirm before using this with integer types.
VTypeInfo Vti = !cast<VTypeInfo>(!subst("VBF", "VF",
!subst("VI", "VF",
!cast<string>(vti))));
}

class MTypeInfo<ValueType Mas, LMULInfo M, string Bx> {
Expand Down
8 changes: 5 additions & 3 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -1394,7 +1394,7 @@ defm : VPatFPSetCCSDNode_VV_VF_FV<SETOLE, "PseudoVMFLE", "PseudoVMFGE">;
// Floating-point vselects:
// 11.15. Vector Integer Merge Instructions
// 13.15. Vector Floating-Point Merge Instruction
foreach fvti = AllFloatVectors in {
foreach fvti = !listconcat(AllFloatVectors, AllBFloatVectors) in {
defvar ivti = GetIntVTypeInfo<fvti>.Vti;
let Predicates = GetVTypePredicates<ivti>.Predicates in {
def : Pat<(fvti.Vector (vselect (fvti.Mask V0), fvti.RegClass:$rs1,
Expand All @@ -1412,7 +1412,9 @@ foreach fvti = AllFloatVectors in {
fvti.RegClass:$rs2, 0, (fvti.Mask V0), fvti.AVL, fvti.Log2SEW)>;

}
let Predicates = GetVTypePredicates<fvti>.Predicates in

let Predicates = !listconcat(GetVTypePredicates<GetFpVTypeInfo<fvti>.Vti>.Predicates,
GetVTypeScalarPredicates<fvti>.Predicates) in
def : Pat<(fvti.Vector (vselect (fvti.Mask V0),
(SplatFPOp fvti.ScalarRegClass:$rs1),
fvti.RegClass:$rs2)),
Expand Down Expand Up @@ -1475,7 +1477,7 @@ foreach fvtiToFWti = AllWidenableBFloatToFloatVectors in {
//===----------------------------------------------------------------------===//

foreach fvti = !listconcat(AllFloatVectors, AllBFloatVectors) in {
let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates,
let Predicates = !listconcat(GetVTypePredicates<GetFpVTypeInfo<fvti>.Vti>.Predicates,
GetVTypeScalarPredicates<fvti>.Predicates) in
def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl undef, fvti.ScalarRegClass:$rs1, srcvalue)),
(!cast<Instruction>("PseudoVFMV_V_"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -2604,7 +2604,7 @@ foreach vti = AllFloatVectors in {
}
}

foreach fvti = AllFloatVectors in {
foreach fvti = !listconcat(AllFloatVectors, AllBFloatVectors) in {
// Floating-point vselects:
// 11.15. Vector Integer Merge Instructions
// 13.15. Vector Floating-Point Merge Instruction
Expand Down Expand Up @@ -2639,7 +2639,8 @@ foreach fvti = AllFloatVectors in {
GPR:$vl, fvti.Log2SEW)>;
}

let Predicates = GetVTypePredicates<fvti>.Predicates in {
let Predicates = !listconcat(GetVTypePredicates<GetFpVTypeInfo<fvti>.Vti>.Predicates,
GetVTypeScalarPredicates<fvti>.Predicates) in {
def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
(SplatFPOp fvti.ScalarRegClass:$rs1),
fvti.RegClass:$rs2,
Expand Down
128 changes: 124 additions & 4 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-select-fp.ll
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v -target-abi=ilp32d \
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=ilp32d \
; RUN: -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v -target-abi=lp64d \
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=lp64d \
; RUN: -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+v,+m -target-abi=ilp32d -riscv-v-vector-bits-min=128 \
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+v,+m,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=ilp32d -riscv-v-vector-bits-min=128 \
; RUN: -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+v,+m -target-abi=lp64d -riscv-v-vector-bits-min=128 \
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+v,+m,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=lp64d -riscv-v-vector-bits-min=128 \
; RUN: -verify-machineinstrs < %s | FileCheck %s

define <2 x half> @select_v2f16(i1 zeroext %c, <2 x half> %a, <2 x half> %b) {
Expand Down Expand Up @@ -343,3 +343,123 @@ define <16 x double> @selectcc_v16f64(double %a, double %b, <16 x double> %c, <1
%v = select i1 %cmp, <16 x double> %c, <16 x double> %d
ret <16 x double> %v
}

; Selects between two <2 x bfloat> vectors using a scalar i1 condition.
; The expected code splats the condition (a0) at e8/mf8, compares it against
; zero to form the mask in v0, then merges the operands with vmerge.vvm at
; e16/mf4.
define <2 x bfloat> @select_v2bf16(i1 zeroext %c, <2 x bfloat> %a, <2 x bfloat> %b) {
; CHECK-LABEL: select_v2bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
; CHECK-NEXT: vmv.v.x v10, a0
; CHECK-NEXT: vmsne.vi v0, v10, 0
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0
; CHECK-NEXT: ret
%v = select i1 %c, <2 x bfloat> %a, <2 x bfloat> %b
ret <2 x bfloat> %v
}

; Selects between two <2 x bfloat> vectors where the condition is an
; `fcmp oeq` of two bfloat scalars. The expected code widens both scalars with
; fcvt.s.bf16, compares them with feq.s, splats the result at e8/mf8 to form
; the mask in v0, then merges the operands with vmerge.vvm at e16/mf4.
define <2 x bfloat> @selectcc_v2bf16(bfloat %a, bfloat %b, <2 x bfloat> %c, <2 x bfloat> %d) {
; CHECK-LABEL: selectcc_v2bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: fcvt.s.bf16 fa5, fa1
; CHECK-NEXT: fcvt.s.bf16 fa4, fa0
; CHECK-NEXT: feq.s a0, fa4, fa5
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
; CHECK-NEXT: vmv.v.x v10, a0
; CHECK-NEXT: vmsne.vi v0, v10, 0
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0
; CHECK-NEXT: ret
%cmp = fcmp oeq bfloat %a, %b
%v = select i1 %cmp, <2 x bfloat> %c, <2 x bfloat> %d
ret <2 x bfloat> %v
}

; Selects between two <4 x bfloat> vectors using a scalar i1 condition.
; The expected code splats the condition (a0) at e8/mf4, compares it against
; zero to form the mask in v0, then merges the operands with vmerge.vvm at
; e16/mf2.
define <4 x bfloat> @select_v4bf16(i1 zeroext %c, <4 x bfloat> %a, <4 x bfloat> %b) {
; CHECK-LABEL: select_v4bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
; CHECK-NEXT: vmv.v.x v10, a0
; CHECK-NEXT: vmsne.vi v0, v10, 0
; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, ma
; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0
; CHECK-NEXT: ret
%v = select i1 %c, <4 x bfloat> %a, <4 x bfloat> %b
ret <4 x bfloat> %v
}

; Selects between two <4 x bfloat> vectors where the condition is an
; `fcmp oeq` of two bfloat scalars. The expected code widens both scalars with
; fcvt.s.bf16, compares them with feq.s, splats the result at e8/mf4 to form
; the mask in v0, then merges the operands with vmerge.vvm at e16/mf2.
define <4 x bfloat> @selectcc_v4bf16(bfloat %a, bfloat %b, <4 x bfloat> %c, <4 x bfloat> %d) {
; CHECK-LABEL: selectcc_v4bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: fcvt.s.bf16 fa5, fa1
; CHECK-NEXT: fcvt.s.bf16 fa4, fa0
; CHECK-NEXT: feq.s a0, fa4, fa5
; CHECK-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
; CHECK-NEXT: vmv.v.x v10, a0
; CHECK-NEXT: vmsne.vi v0, v10, 0
; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, ma
; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0
; CHECK-NEXT: ret
%cmp = fcmp oeq bfloat %a, %b
%v = select i1 %cmp, <4 x bfloat> %c, <4 x bfloat> %d
ret <4 x bfloat> %v
}

; Selects between two <8 x bfloat> vectors using a scalar i1 condition.
; The expected code splats the condition (a0) at e8/mf2, compares it against
; zero to form the mask in v0, then merges the operands with vmerge.vvm at
; e16/m1.
define <8 x bfloat> @select_v8bf16(i1 zeroext %c, <8 x bfloat> %a, <8 x bfloat> %b) {
; CHECK-LABEL: select_v8bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vmv.v.x v10, a0
; CHECK-NEXT: vmsne.vi v0, v10, 0
; CHECK-NEXT: vsetvli zero, zero, e16, m1, ta, ma
; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0
; CHECK-NEXT: ret
%v = select i1 %c, <8 x bfloat> %a, <8 x bfloat> %b
ret <8 x bfloat> %v
}

; Selects between two <8 x bfloat> vectors where the condition is an
; `fcmp oeq` of two bfloat scalars. The expected code widens both scalars with
; fcvt.s.bf16, compares them with feq.s, splats the result at e8/mf2 to form
; the mask in v0, then merges the operands with vmerge.vvm at e16/m1.
define <8 x bfloat> @selectcc_v8bf16(bfloat %a, bfloat %b, <8 x bfloat> %c, <8 x bfloat> %d) {
; CHECK-LABEL: selectcc_v8bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: fcvt.s.bf16 fa5, fa1
; CHECK-NEXT: fcvt.s.bf16 fa4, fa0
; CHECK-NEXT: feq.s a0, fa4, fa5
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vmv.v.x v10, a0
; CHECK-NEXT: vmsne.vi v0, v10, 0
; CHECK-NEXT: vsetvli zero, zero, e16, m1, ta, ma
; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0
; CHECK-NEXT: ret
%cmp = fcmp oeq bfloat %a, %b
%v = select i1 %cmp, <8 x bfloat> %c, <8 x bfloat> %d
ret <8 x bfloat> %v
}

; Selects between two <16 x bfloat> vectors using a scalar i1 condition.
; The expected code splats the condition (a0) at e8/m1, compares it against
; zero to form the mask in v0, then merges the operands with vmerge.vvm at
; e16/m2.
define <16 x bfloat> @select_v16bf16(i1 zeroext %c, <16 x bfloat> %a, <16 x bfloat> %b) {
; CHECK-LABEL: select_v16bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT: vmv.v.x v12, a0
; CHECK-NEXT: vmsne.vi v0, v12, 0
; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
; CHECK-NEXT: vmerge.vvm v8, v10, v8, v0
; CHECK-NEXT: ret
%v = select i1 %c, <16 x bfloat> %a, <16 x bfloat> %b
ret <16 x bfloat> %v
}

; Selects between two <16 x bfloat> vectors where the condition is an
; `fcmp oeq` of two bfloat scalars. The expected code widens both scalars with
; fcvt.s.bf16, compares them with feq.s, splats the result at e8/m1 to form
; the mask in v0, then merges the operands with vmerge.vvm at e16/m2.
define <16 x bfloat> @selectcc_v16bf16(bfloat %a, bfloat %b, <16 x bfloat> %c, <16 x bfloat> %d) {
; CHECK-LABEL: selectcc_v16bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: fcvt.s.bf16 fa5, fa1
; CHECK-NEXT: fcvt.s.bf16 fa4, fa0
; CHECK-NEXT: feq.s a0, fa4, fa5
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT: vmv.v.x v12, a0
; CHECK-NEXT: vmsne.vi v0, v12, 0
; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
; CHECK-NEXT: vmerge.vvm v8, v10, v8, v0
; CHECK-NEXT: ret
%cmp = fcmp oeq bfloat %a, %b
%v = select i1 %cmp, <16 x bfloat> %c, <16 x bfloat> %d
ret <16 x bfloat> %v
}
Loading

0 comments on commit d5ab38f

Please sign in to comment.