Skip to content

Commit

Permalink
[DAG] SDPatternMatch - Add m_ExtractElt and m_InsertElt matchers (#11…
Browse files Browse the repository at this point in the history
…9430)

Resolves #118844
  • Loading branch information
whiteio authored Dec 13, 2024
1 parent 71d2fa7 commit ecdf0da
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
12 changes: 12 additions & 0 deletions llvm/include/llvm/CodeGen/SDPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,13 @@ m_VSelect(const T0_P &Cond, const T1_P &T, const T2_P &F) {
return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::VSELECT, Cond, T, F);
}

template <typename T0_P, typename T1_P, typename T2_P>
inline TernaryOpc_match<T0_P, T1_P, T2_P>
m_InsertElt(const T0_P &Vec, const T1_P &Val, const T2_P &Idx) {
return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::INSERT_VECTOR_ELT, Vec, Val,
Idx);
}

// === Binary operations ===
template <typename LHS_P, typename RHS_P, bool Commutable = false,
bool ExcludeChain = false>
Expand Down Expand Up @@ -790,6 +797,11 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
}

template <typename LHS, typename RHS>
inline BinaryOpc_match<LHS, RHS> m_ExtractElt(const LHS &Vec, const RHS &Idx) {
return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_VECTOR_ELT, Vec, Idx);
}

// === Unary operations ===
template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
unsigned Opcode;
Expand Down
22 changes: 22 additions & 0 deletions llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {

SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
SDValue Op3 = DAG->getConstant(1, DL, Int32VT);

SDValue ICMP_UGT = DAG->getSetCC(DL, MVT::i1, Op0, Op1, ISD::SETUGT);
SDValue ICMP_EQ01 = DAG->getSetCC(DL, MVT::i1, Op0, Op1, ISD::SETEQ);
Expand All @@ -141,6 +142,9 @@ TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {
SDValue V2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 7, VInt32VT);
SDValue VSelect = DAG->getNode(ISD::VSELECT, DL, VInt32VT, Cond, V1, V2);

SDValue ExtractELT =
DAG->getNode(ISD::EXTRACT_VECTOR_ELT, DL, Int32VT, V1, Op3);

using namespace SDPatternMatch;
ISD::CondCode CC;
EXPECT_TRUE(sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(),
Expand Down Expand Up @@ -174,17 +178,25 @@ TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {
VSelect, m_VSelect(m_Specific(Cond), m_Specific(V1), m_Specific(V2))));
EXPECT_FALSE(sd_match(
Select, m_VSelect(m_Specific(Cond), m_Specific(V1), m_Specific(V2))));

EXPECT_TRUE(sd_match(ExtractELT, m_ExtractElt(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(ExtractELT, m_ExtractElt(m_Value(), m_ConstInt())));
EXPECT_TRUE(sd_match(ExtractELT, m_ExtractElt(m_Value(), m_SpecificInt(1))));
}

TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
SDLoc DL;
auto Int32VT = EVT::getIntegerVT(Context, 32);
auto Float32VT = EVT::getFloatingPointVT(32);
auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);

SDValue V1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 6, VInt32VT);

SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
SDValue Op4 = DAG->getConstant(1, DL, Int32VT);

SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
Expand Down Expand Up @@ -221,6 +233,9 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
SDValue SFAdd = DAG->getNode(ISD::STRICT_FADD, DL, {Float32VT, MVT::Other},
{DAG->getEntryNode(), Op2, Op2});

SDValue InsertELT =
DAG->getNode(ISD::INSERT_VECTOR_ELT, DL, VInt32VT, V1, Op0, Op4);

using namespace SDPatternMatch;
EXPECT_TRUE(sd_match(Sub, m_BinOp(ISD::SUB, m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value())));
Expand Down Expand Up @@ -277,6 +292,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
m_Deferred(BindVal))));
EXPECT_FALSE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_OtherVT(),
m_SpecificVT(Float32VT))));

EXPECT_TRUE(
sd_match(InsertELT, m_InsertElt(m_Value(), m_Value(), m_Value())));
EXPECT_TRUE(
sd_match(InsertELT, m_InsertElt(m_Value(), m_Value(), m_ConstInt())));
EXPECT_TRUE(
sd_match(InsertELT, m_InsertElt(m_Value(), m_Value(), m_SpecificInt(1))));
}

TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
Expand Down

0 comments on commit ecdf0da

Please sign in to comment.