Skip to content

Commit

Permalink
[GlobalIsel] Combine G_ADD and G_SUB with constants (#97771)
Browse files Browse the repository at this point in the history
  • Loading branch information
Thorsten Schütt authored Aug 9, 2024
1 parent edf45e4 commit 6b77531
Show file tree
Hide file tree
Showing 8 changed files with 399 additions and 23 deletions.
10 changes: 10 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,16 @@ class CombinerHelper {

bool matchCastOfSelect(const MachineInstr &Cast, const MachineInstr &SelectMI,
BuildFnTy &MatchInfo);
bool matchFoldAPlusC1MinusC2(const MachineInstr &MI, BuildFnTy &MatchInfo);

bool matchFoldC2MinusAPlusC1(const MachineInstr &MI, BuildFnTy &MatchInfo);

bool matchFoldAMinusC1MinusC2(const MachineInstr &MI, BuildFnTy &MatchInfo);

bool matchFoldC1Minus2MinusC2(const MachineInstr &MI, BuildFnTy &MatchInfo);

// fold ((A-C1)+C2) -> (A+(C2-C1))
bool matchFoldAMinusC1PlusC2(const MachineInstr &MI, BuildFnTy &MatchInfo);

private:
/// Checks for legality of an indexed variant of \p LdSt.
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "llvm/IR/DebugLoc.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"

#include <cstdint>

namespace llvm {
Expand Down Expand Up @@ -178,6 +179,9 @@ std::optional<APInt> getIConstantVRegVal(Register VReg,
std::optional<int64_t> getIConstantVRegSExtVal(Register VReg,
const MachineRegisterInfo &MRI);

/// \p VReg is defined by a G_CONSTANT, return the corresponding value.
APInt getIConstantFromReg(Register VReg, const MachineRegisterInfo &MRI);

/// Simple struct used to hold a constant integer value and a virtual
/// register.
struct ValueAndVReg {
Expand Down
57 changes: 56 additions & 1 deletion llvm/include/llvm/Target/GlobalISel/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -1747,6 +1747,56 @@ def APlusBMinusCPlusA : GICombineRule<
(G_ADD $root, $A, $sub1)),
(apply (G_SUB $root, $B, $C))>;

// fold (A+C1)-C2 -> A+(C1-C2)
def APlusC1MinusC2: GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_CONSTANT $c2, $imm2),
(G_CONSTANT $c1, $imm1),
(G_ADD $add, $A, $c1),
(G_SUB $root, $add, $c2):$root,
[{ return Helper.matchFoldAPlusC1MinusC2(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

// fold C2-(A+C1) -> (C2-C1)-A
def C2MinusAPlusC1: GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_CONSTANT $c2, $imm2),
(G_CONSTANT $c1, $imm1),
(G_ADD $add, $A, $c1),
(G_SUB $root, $c2, $add):$root,
[{ return Helper.matchFoldC2MinusAPlusC1(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

// fold (A-C1)-C2 -> A-(C1+C2)
def AMinusC1MinusC2: GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_CONSTANT $c2, $imm2),
(G_CONSTANT $c1, $imm1),
(G_SUB $sub1, $A, $c1),
(G_SUB $root, $sub1, $c2):$root,
[{ return Helper.matchFoldAMinusC1MinusC2(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

// fold (C1-A)-C2 -> (C1-C2)-A
def C1Minus2MinusC2: GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_CONSTANT $c2, $imm2),
(G_CONSTANT $c1, $imm1),
(G_SUB $sub1, $c1, $A),
(G_SUB $root, $sub1, $c2):$root,
[{ return Helper.matchFoldC1Minus2MinusC2(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

// fold ((A-C1)+C2) -> (A+(C2-C1))
def AMinusC1PlusC2: GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_CONSTANT $c2, $imm2),
(G_CONSTANT $c1, $imm1),
(G_SUB $sub, $A, $c1),
(G_ADD $root, $sub, $c2):$root,
[{ return Helper.matchFoldAMinusC1PlusC2(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

def integer_reassoc_combines: GICombineGroup<[
ZeroMinusAPlusB,
APlusZeroMinusB,
Expand All @@ -1755,7 +1805,12 @@ def integer_reassoc_combines: GICombineGroup<[
AMinusBPlusCMinusA,
AMinusBPlusBMinusC,
APlusBMinusAplusC,
APlusBMinusCPlusA
APlusBMinusCPlusA,
APlusC1MinusC2,
C2MinusAPlusC1,
AMinusC1MinusC2,
C1Minus2MinusC2,
AMinusC1PlusC2
]>;

def freeze_of_non_undef_non_poison : GICombineRule<
Expand Down
115 changes: 115 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7433,3 +7433,118 @@ void CombinerHelper::applyExpandFPowI(MachineInstr &MI, int64_t Exponent) {
Builder.buildCopy(Dst, *Res);
MI.eraseFromParent();
}

bool CombinerHelper::matchFoldAPlusC1MinusC2(const MachineInstr &MI,
BuildFnTy &MatchInfo) {
// fold (A+C1)-C2 -> A+(C1-C2)
const GSub *Sub = cast<GSub>(&MI);
GAdd *Add = cast<GAdd>(MRI.getVRegDef(Sub->getLHSReg()));

if (!MRI.hasOneNonDBGUse(Add->getReg(0)))
return false;

APInt C2 = getIConstantFromReg(Sub->getRHSReg(), MRI);
APInt C1 = getIConstantFromReg(Add->getRHSReg(), MRI);

Register Dst = Sub->getReg(0);
LLT DstTy = MRI.getType(Dst);

MatchInfo = [=](MachineIRBuilder &B) {
auto Const = B.buildConstant(DstTy, C1 - C2);
B.buildAdd(Dst, Add->getLHSReg(), Const);
};

return true;
}

bool CombinerHelper::matchFoldC2MinusAPlusC1(const MachineInstr &MI,
BuildFnTy &MatchInfo) {
// fold C2-(A+C1) -> (C2-C1)-A
const GSub *Sub = cast<GSub>(&MI);
GAdd *Add = cast<GAdd>(MRI.getVRegDef(Sub->getRHSReg()));

if (!MRI.hasOneNonDBGUse(Add->getReg(0)))
return false;

APInt C2 = getIConstantFromReg(Sub->getLHSReg(), MRI);
APInt C1 = getIConstantFromReg(Add->getRHSReg(), MRI);

Register Dst = Sub->getReg(0);
LLT DstTy = MRI.getType(Dst);

MatchInfo = [=](MachineIRBuilder &B) {
auto Const = B.buildConstant(DstTy, C2 - C1);
B.buildSub(Dst, Const, Add->getLHSReg());
};

return true;
}

bool CombinerHelper::matchFoldAMinusC1MinusC2(const MachineInstr &MI,
BuildFnTy &MatchInfo) {
// fold (A-C1)-C2 -> A-(C1+C2)
const GSub *Sub1 = cast<GSub>(&MI);
GSub *Sub2 = cast<GSub>(MRI.getVRegDef(Sub1->getLHSReg()));

if (!MRI.hasOneNonDBGUse(Sub2->getReg(0)))
return false;

APInt C2 = getIConstantFromReg(Sub1->getRHSReg(), MRI);
APInt C1 = getIConstantFromReg(Sub2->getRHSReg(), MRI);

Register Dst = Sub1->getReg(0);
LLT DstTy = MRI.getType(Dst);

MatchInfo = [=](MachineIRBuilder &B) {
auto Const = B.buildConstant(DstTy, C1 + C2);
B.buildSub(Dst, Sub2->getLHSReg(), Const);
};

return true;
}

bool CombinerHelper::matchFoldC1Minus2MinusC2(const MachineInstr &MI,
BuildFnTy &MatchInfo) {
// fold (C1-A)-C2 -> (C1-C2)-A
const GSub *Sub1 = cast<GSub>(&MI);
GSub *Sub2 = cast<GSub>(MRI.getVRegDef(Sub1->getLHSReg()));

if (!MRI.hasOneNonDBGUse(Sub2->getReg(0)))
return false;

APInt C2 = getIConstantFromReg(Sub1->getRHSReg(), MRI);
APInt C1 = getIConstantFromReg(Sub2->getLHSReg(), MRI);

Register Dst = Sub1->getReg(0);
LLT DstTy = MRI.getType(Dst);

MatchInfo = [=](MachineIRBuilder &B) {
auto Const = B.buildConstant(DstTy, C1 - C2);
B.buildSub(Dst, Const, Sub2->getRHSReg());
};

return true;
}

bool CombinerHelper::matchFoldAMinusC1PlusC2(const MachineInstr &MI,
BuildFnTy &MatchInfo) {
// fold ((A-C1)+C2) -> (A+(C2-C1))
const GAdd *Add = cast<GAdd>(&MI);
GSub *Sub = cast<GSub>(MRI.getVRegDef(Add->getLHSReg()));

if (!MRI.hasOneNonDBGUse(Sub->getReg(0)))
return false;

APInt C2 = getIConstantFromReg(Add->getRHSReg(), MRI);
APInt C1 = getIConstantFromReg(Sub->getRHSReg(), MRI);

Register Dst = Add->getReg(0);
LLT DstTy = MRI.getType(Dst);

MatchInfo = [=](MachineIRBuilder &B) {
auto Const = B.buildConstant(DstTy, C2 - C1);
B.buildAdd(Dst, Sub->getLHSReg(), Const);
};

return true;
}
7 changes: 7 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,13 @@ std::optional<APInt> llvm::getIConstantVRegVal(Register VReg,
return ValAndVReg->Value;
}

APInt llvm::getIConstantFromReg(Register Reg, const MachineRegisterInfo &MRI) {
MachineInstr *Const = MRI.getVRegDef(Reg);
assert((Const && Const->getOpcode() == TargetOpcode::G_CONSTANT) &&
"expected a G_CONSTANT on Reg");
return Const->getOperand(1).getCImm()->getValue();
}

std::optional<int64_t>
llvm::getIConstantVRegSExtVal(Register VReg, const MachineRegisterInfo &MRI) {
std::optional<APInt> Val = getIConstantVRegVal(VReg, MRI);
Expand Down
Loading

0 comments on commit 6b77531

Please sign in to comment.