Skip to content

Commit

Permalink
[RISCV] Emit VP strided load in mgather combine. NFCI (llvm#98112)
Browse files Browse the repository at this point in the history
This combine is a duplication of the transform in
RISCVGatherScatterLowering but at the SelectionDAG level, so similarly
to llvm#98111 we can replace the use of riscv_masked_strided_load with a VP
strided load.

Unlike llvm#98111 we don't require llvm#97800 or llvm#97798 since it only operates
on fixed vectors with a non-zero stride.
  • Loading branch information
lukel97 authored and aaryanshukla committed Jul 14, 2024
1 parent d9ce1d2 commit 5696d46
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17062,15 +17062,16 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
SDValue BasePtr = DAG.getNode(ISD::ADD, DL, PtrVT, MGN->getBasePtr(),
DAG.getConstant(Addend, DL, PtrVT));

SDVTList VTs = DAG.getVTList({VT, MVT::Other});
SDValue IntID =
DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
XLenVT);
SDValue Ops[] =
{MGN->getChain(), IntID, MGN->getPassThru(), BasePtr,
DAG.getConstant(StepNumerator, DL, XLenVT), MGN->getMask()};
return DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs,
Ops, VT, MGN->getMemOperand());
SDValue EVL = DAG.getElementCount(DL, Subtarget.getXLenVT(),
VT.getVectorElementCount());
SDValue StridedLoad =
DAG.getStridedLoadVP(VT, DL, MGN->getChain(), BasePtr,
DAG.getConstant(StepNumerator, DL, XLenVT),
MGN->getMask(), EVL, MGN->getMemOperand());
SDValue VPSelect = DAG.getNode(ISD::VP_SELECT, DL, VT, MGN->getMask(),
StridedLoad, MGN->getPassThru(), EVL);
return DAG.getMergeValues({VPSelect, SDValue(StridedLoad.getNode(), 1)},
DL);
}
}

Expand Down

0 comments on commit 5696d46

Please sign in to comment.