Skip to content

Commit

Permalink
[RISCV] Lower memory ops and VP splat for zvfhmin and zvfbfmin
Browse files Browse the repository at this point in the history
We can lower f16/bf16 memory ops without promotion through the existing custom lowering.

Some of the zero strided VP loads get combined to a VP splat, so we need to also handle the lowering for that for f16/bf16 w/ zvfhmin/zvfbfmin. This patch copies the lowering from ISD::SPLAT_VECTOR over to lowerScalarSplat which is used by the VP splat lowering.
  • Loading branch information
lukel97 committed Sep 20, 2024
1 parent 400b725 commit d373bcc
Show file tree
Hide file tree
Showing 12 changed files with 1,518 additions and 155 deletions.
33 changes: 28 additions & 5 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1082,10 +1082,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
VT, Custom);
MVT EltVT = VT.getVectorElementType();
if (isTypeLegal(EltVT))
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT}, VT,
Custom);
else
setOperationAction(ISD::SPLAT_VECTOR, EltVT, Custom);
setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT},
EltVT, Custom);
setOperationAction({ISD::LOAD, ISD::STORE, ISD::MLOAD, ISD::MSTORE,
ISD::MGATHER, ISD::MSCATTER, ISD::VP_LOAD,
ISD::VP_STORE, ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER,
ISD::VP_SCATTER},
VT, Custom);

setOperationAction(ISD::FNEG, VT, Expand);
setOperationAction(ISD::FABS, VT, Expand);
Expand Down Expand Up @@ -4449,11 +4456,27 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
bool HasPassthru = Passthru && !Passthru.isUndef();
if (!HasPassthru && !Passthru)
Passthru = DAG.getUNDEF(VT);
if (VT.isFloatingPoint())
return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, VT, Passthru, Scalar, VL);

MVT EltVT = VT.getVectorElementType();
MVT XLenVT = Subtarget.getXLenVT();

if (VT.isFloatingPoint()) {
if ((EltVT == MVT::f16 && !Subtarget.hasStdExtZvfh()) ||
EltVT == MVT::bf16) {
if ((EltVT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) ||
(EltVT == MVT::f16 && Subtarget.hasStdExtZfhmin()))
Scalar = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Scalar);
else
Scalar = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Scalar);
MVT IVT = VT.changeVectorElementType(MVT::i16);
Passthru = DAG.getNode(ISD::BITCAST, DL, IVT, Passthru);
SDValue Splat =
lowerScalarSplat(Passthru, Scalar, VL, IVT, DL, DAG, Subtarget);
return DAG.getNode(ISD::BITCAST, DL, VT, Splat);
}
return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, VT, Passthru, Scalar, VL);
}

// Simplest case is that the operand needs to be promoted to XLenVT.
if (Scalar.getValueType().bitsLE(XLenVT)) {
// If the operand is a constant, sign extend to increase our chances
Expand Down
72 changes: 70 additions & 2 deletions llvm/test/CodeGen/RISCV/rvv/masked-load-fp.ll
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+zvfbfmin,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+zvfbfmin,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+zvfbfmin,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+zvfbfmin,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s

define <vscale x 1 x bfloat> @masked_load_nxv1bf16(ptr %a, <vscale x 1 x i1> %mask) nounwind {
; CHECK-LABEL: masked_load_nxv1bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e16, mf4, ta, ma
; CHECK-NEXT: vle16.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 1 x bfloat> @llvm.masked.load.nxv1bf16(ptr %a, i32 2, <vscale x 1 x i1> %mask, <vscale x 1 x bfloat> undef)
ret <vscale x 1 x bfloat> %load
}
declare <vscale x 1 x bfloat> @llvm.masked.load.nxv1bf16(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x bfloat>)

define <vscale x 1 x half> @masked_load_nxv1f16(ptr %a, <vscale x 1 x i1> %mask) nounwind {
; CHECK-LABEL: masked_load_nxv1f16:
Expand Down Expand Up @@ -35,6 +48,17 @@ define <vscale x 1 x double> @masked_load_nxv1f64(ptr %a, <vscale x 1 x i1> %mas
}
declare <vscale x 1 x double> @llvm.masked.load.nxv1f64(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x double>)

define <vscale x 2 x bfloat> @masked_load_nxv2bf16(ptr %a, <vscale x 2 x i1> %mask) nounwind {
; CHECK-LABEL: masked_load_nxv2bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e16, mf2, ta, ma
; CHECK-NEXT: vle16.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 2 x bfloat> @llvm.masked.load.nxv2bf16(ptr %a, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x bfloat> undef)
ret <vscale x 2 x bfloat> %load
}
declare <vscale x 2 x bfloat> @llvm.masked.load.nxv2bf16(ptr, i32, <vscale x 2 x i1>, <vscale x 2 x bfloat>)

define <vscale x 2 x half> @masked_load_nxv2f16(ptr %a, <vscale x 2 x i1> %mask) nounwind {
; CHECK-LABEL: masked_load_nxv2f16:
; CHECK: # %bb.0:
Expand Down Expand Up @@ -68,6 +92,17 @@ define <vscale x 2 x double> @masked_load_nxv2f64(ptr %a, <vscale x 2 x i1> %mas
}
declare <vscale x 2 x double> @llvm.masked.load.nxv2f64(ptr, i32, <vscale x 2 x i1>, <vscale x 2 x double>)

define <vscale x 4 x bfloat> @masked_load_nxv4bf16(ptr %a, <vscale x 4 x i1> %mask) nounwind {
; CHECK-LABEL: masked_load_nxv4bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
; CHECK-NEXT: vle16.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 4 x bfloat> @llvm.masked.load.nxv4bf16(ptr %a, i32 2, <vscale x 4 x i1> %mask, <vscale x 4 x bfloat> undef)
ret <vscale x 4 x bfloat> %load
}
declare <vscale x 4 x bfloat> @llvm.masked.load.nxv4bf16(ptr, i32, <vscale x 4 x i1>, <vscale x 4 x bfloat>)

define <vscale x 4 x half> @masked_load_nxv4f16(ptr %a, <vscale x 4 x i1> %mask) nounwind {
; CHECK-LABEL: masked_load_nxv4f16:
; CHECK: # %bb.0:
Expand Down Expand Up @@ -101,6 +136,17 @@ define <vscale x 4 x double> @masked_load_nxv4f64(ptr %a, <vscale x 4 x i1> %mas
}
declare <vscale x 4 x double> @llvm.masked.load.nxv4f64(ptr, i32, <vscale x 4 x i1>, <vscale x 4 x double>)

define <vscale x 8 x bfloat> @masked_load_nxv8bf16(ptr %a, <vscale x 8 x i1> %mask) nounwind {
; CHECK-LABEL: masked_load_nxv8bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e16, m2, ta, ma
; CHECK-NEXT: vle16.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 8 x bfloat> @llvm.masked.load.nxv8bf16(ptr %a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x bfloat> undef)
ret <vscale x 8 x bfloat> %load
}
declare <vscale x 8 x bfloat> @llvm.masked.load.nxv8bf16(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x bfloat>)

define <vscale x 8 x half> @masked_load_nxv8f16(ptr %a, <vscale x 8 x i1> %mask) nounwind {
; CHECK-LABEL: masked_load_nxv8f16:
; CHECK: # %bb.0:
Expand Down Expand Up @@ -134,6 +180,17 @@ define <vscale x 8 x double> @masked_load_nxv8f64(ptr %a, <vscale x 8 x i1> %mas
}
declare <vscale x 8 x double> @llvm.masked.load.nxv8f64(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x double>)

define <vscale x 16 x bfloat> @masked_load_nxv16bf16(ptr %a, <vscale x 16 x i1> %mask) nounwind {
; CHECK-LABEL: masked_load_nxv16bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e16, m4, ta, ma
; CHECK-NEXT: vle16.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 16 x bfloat> @llvm.masked.load.nxv16bf16(ptr %a, i32 2, <vscale x 16 x i1> %mask, <vscale x 16 x bfloat> undef)
ret <vscale x 16 x bfloat> %load
}
declare <vscale x 16 x bfloat> @llvm.masked.load.nxv16bf16(ptr, i32, <vscale x 16 x i1>, <vscale x 16 x bfloat>)

define <vscale x 16 x half> @masked_load_nxv16f16(ptr %a, <vscale x 16 x i1> %mask) nounwind {
; CHECK-LABEL: masked_load_nxv16f16:
; CHECK: # %bb.0:
Expand All @@ -156,6 +213,17 @@ define <vscale x 16 x float> @masked_load_nxv16f32(ptr %a, <vscale x 16 x i1> %m
}
declare <vscale x 16 x float> @llvm.masked.load.nxv16f32(ptr, i32, <vscale x 16 x i1>, <vscale x 16 x float>)

define <vscale x 32 x bfloat> @masked_load_nxv32bf16(ptr %a, <vscale x 32 x i1> %mask) nounwind {
; CHECK-LABEL: masked_load_nxv32bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e16, m8, ta, ma
; CHECK-NEXT: vle16.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 32 x bfloat> @llvm.masked.load.nxv32bf16(ptr %a, i32 2, <vscale x 32 x i1> %mask, <vscale x 32 x bfloat> undef)
ret <vscale x 32 x bfloat> %load
}
declare <vscale x 32 x bfloat> @llvm.masked.load.nxv32bf16(ptr, i32, <vscale x 32 x i1>, <vscale x 32 x bfloat>)

define <vscale x 32 x half> @masked_load_nxv32f16(ptr %a, <vscale x 32 x i1> %mask) nounwind {
; CHECK-LABEL: masked_load_nxv32f16:
; CHECK: # %bb.0:
Expand Down
72 changes: 70 additions & 2 deletions llvm/test/CodeGen/RISCV/rvv/masked-store-fp.ll
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+zvfbfmin,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+zvfbfmin,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+zvfbfmin,+v -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+zvfbfmin,+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s

define void @masked_store_nxv1bf16(<vscale x 1 x bfloat> %val, ptr %a, <vscale x 1 x i1> %mask) nounwind {
; CHECK-LABEL: masked_store_nxv1bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e16, mf4, ta, ma
; CHECK-NEXT: vse16.v v8, (a0), v0.t
; CHECK-NEXT: ret
call void @llvm.masked.store.nxv1bf16.p0(<vscale x 1 x bfloat> %val, ptr %a, i32 2, <vscale x 1 x i1> %mask)
ret void
}
declare void @llvm.masked.store.nxv1bf16.p0(<vscale x 1 x bfloat>, ptr, i32, <vscale x 1 x i1>)

define void @masked_store_nxv1f16(<vscale x 1 x half> %val, ptr %a, <vscale x 1 x i1> %mask) nounwind {
; CHECK-LABEL: masked_store_nxv1f16:
Expand Down Expand Up @@ -35,6 +48,17 @@ define void @masked_store_nxv1f64(<vscale x 1 x double> %val, ptr %a, <vscale x
}
declare void @llvm.masked.store.nxv1f64.p0(<vscale x 1 x double>, ptr, i32, <vscale x 1 x i1>)

define void @masked_store_nxv2bf16(<vscale x 2 x bfloat> %val, ptr %a, <vscale x 2 x i1> %mask) nounwind {
; CHECK-LABEL: masked_store_nxv2bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e16, mf2, ta, ma
; CHECK-NEXT: vse16.v v8, (a0), v0.t
; CHECK-NEXT: ret
call void @llvm.masked.store.nxv2bf16.p0(<vscale x 2 x bfloat> %val, ptr %a, i32 2, <vscale x 2 x i1> %mask)
ret void
}
declare void @llvm.masked.store.nxv2bf16.p0(<vscale x 2 x bfloat>, ptr, i32, <vscale x 2 x i1>)

define void @masked_store_nxv2f16(<vscale x 2 x half> %val, ptr %a, <vscale x 2 x i1> %mask) nounwind {
; CHECK-LABEL: masked_store_nxv2f16:
; CHECK: # %bb.0:
Expand Down Expand Up @@ -68,6 +92,17 @@ define void @masked_store_nxv2f64(<vscale x 2 x double> %val, ptr %a, <vscale x
}
declare void @llvm.masked.store.nxv2f64.p0(<vscale x 2 x double>, ptr, i32, <vscale x 2 x i1>)

define void @masked_store_nxv4bf16(<vscale x 4 x bfloat> %val, ptr %a, <vscale x 4 x i1> %mask) nounwind {
; CHECK-LABEL: masked_store_nxv4bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
; CHECK-NEXT: vse16.v v8, (a0), v0.t
; CHECK-NEXT: ret
call void @llvm.masked.store.nxv4bf16.p0(<vscale x 4 x bfloat> %val, ptr %a, i32 2, <vscale x 4 x i1> %mask)
ret void
}
declare void @llvm.masked.store.nxv4bf16.p0(<vscale x 4 x bfloat>, ptr, i32, <vscale x 4 x i1>)

define void @masked_store_nxv4f16(<vscale x 4 x half> %val, ptr %a, <vscale x 4 x i1> %mask) nounwind {
; CHECK-LABEL: masked_store_nxv4f16:
; CHECK: # %bb.0:
Expand Down Expand Up @@ -101,6 +136,17 @@ define void @masked_store_nxv4f64(<vscale x 4 x double> %val, ptr %a, <vscale x
}
declare void @llvm.masked.store.nxv4f64.p0(<vscale x 4 x double>, ptr, i32, <vscale x 4 x i1>)

define void @masked_store_nxv8bf16(<vscale x 8 x bfloat> %val, ptr %a, <vscale x 8 x i1> %mask) nounwind {
; CHECK-LABEL: masked_store_nxv8bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e16, m2, ta, ma
; CHECK-NEXT: vse16.v v8, (a0), v0.t
; CHECK-NEXT: ret
call void @llvm.masked.store.nxv8bf16.p0(<vscale x 8 x bfloat> %val, ptr %a, i32 2, <vscale x 8 x i1> %mask)
ret void
}
declare void @llvm.masked.store.nxv8bf16.p0(<vscale x 8 x bfloat>, ptr, i32, <vscale x 8 x i1>)

define void @masked_store_nxv8f16(<vscale x 8 x half> %val, ptr %a, <vscale x 8 x i1> %mask) nounwind {
; CHECK-LABEL: masked_store_nxv8f16:
; CHECK: # %bb.0:
Expand Down Expand Up @@ -134,6 +180,17 @@ define void @masked_store_nxv8f64(<vscale x 8 x double> %val, ptr %a, <vscale x
}
declare void @llvm.masked.store.nxv8f64.p0(<vscale x 8 x double>, ptr, i32, <vscale x 8 x i1>)

define void @masked_store_nxv16bf16(<vscale x 16 x bfloat> %val, ptr %a, <vscale x 16 x i1> %mask) nounwind {
; CHECK-LABEL: masked_store_nxv16bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e16, m4, ta, ma
; CHECK-NEXT: vse16.v v8, (a0), v0.t
; CHECK-NEXT: ret
call void @llvm.masked.store.nxv16bf16.p0(<vscale x 16 x bfloat> %val, ptr %a, i32 2, <vscale x 16 x i1> %mask)
ret void
}
declare void @llvm.masked.store.nxv16bf16.p0(<vscale x 16 x bfloat>, ptr, i32, <vscale x 16 x i1>)

define void @masked_store_nxv16f16(<vscale x 16 x half> %val, ptr %a, <vscale x 16 x i1> %mask) nounwind {
; CHECK-LABEL: masked_store_nxv16f16:
; CHECK: # %bb.0:
Expand All @@ -156,6 +213,17 @@ define void @masked_store_nxv16f32(<vscale x 16 x float> %val, ptr %a, <vscale x
}
declare void @llvm.masked.store.nxv16f32.p0(<vscale x 16 x float>, ptr, i32, <vscale x 16 x i1>)

define void @masked_store_nxv32bf16(<vscale x 32 x bfloat> %val, ptr %a, <vscale x 32 x i1> %mask) nounwind {
; CHECK-LABEL: masked_store_nxv32bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e16, m8, ta, ma
; CHECK-NEXT: vse16.v v8, (a0), v0.t
; CHECK-NEXT: ret
call void @llvm.masked.store.nxv32bf16.p0(<vscale x 32 x bfloat> %val, ptr %a, i32 2, <vscale x 32 x i1> %mask)
ret void
}
declare void @llvm.masked.store.nxv32bf16.p0(<vscale x 32 x bfloat>, ptr, i32, <vscale x 32 x i1>)

define void @masked_store_nxv32f16(<vscale x 32 x half> %val, ptr %a, <vscale x 32 x i1> %mask) nounwind {
; CHECK-LABEL: masked_store_nxv32f16:
; CHECK: # %bb.0:
Expand Down
Loading

0 comments on commit d373bcc

Please sign in to comment.