Skip to content

Commit

Permalink
[InstCombine] Simplify (add/sub (sub/add) (sub/add)) irrelivant of …
Browse files Browse the repository at this point in the history
…use-count

Added folds:
    - `(add (sub X, Y), (sub Z, X))` -> `(sub Z, Y)`
    - `(sub (add X, Y), (add X, Z))` -> `(sub Y, Z)`

The fold typically is handled in the `Reassosiate` pass, but it fails
if the inner `sub`/`add` are multi-use. Less importantly, Reassosiate
doesn't propagate flags correctly.

This patch adds the fold explicitly the InstCombine

Proofs: https://alive2.llvm.org/ce/z/p6JyRP

Closes llvm#105866
  • Loading branch information
goldsteinn authored and 5c4lar committed Aug 29, 2024
1 parent ed76c40 commit bac00b0
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 10 deletions.
51 changes: 51 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,24 @@ static Instruction *foldAddToAshr(BinaryOperator &Add) {
X, ConstantInt::get(Add.getType(), DivC->exactLogBase2()));
}

Instruction *InstCombinerImpl::foldAddLikeCommutative(Value *LHS, Value *RHS,
bool NSW, bool NUW) {
Value *A, *B, *C;
if (match(LHS, m_Sub(m_Value(A), m_Value(B))) &&
match(RHS, m_Sub(m_Value(C), m_Specific(A)))) {
Instruction *R = BinaryOperator::CreateSub(C, B);
bool NSWOut = NSW && match(LHS, m_NSWSub(m_Value(), m_Value())) &&
match(RHS, m_NSWSub(m_Value(), m_Value()));

bool NUWOut = match(LHS, m_NUWSub(m_Value(), m_Value())) &&
match(RHS, m_NUWSub(m_Value(), m_Value()));
R->setHasNoSignedWrap(NSWOut);
R->setHasNoUnsignedWrap(NUWOut);
return R;
}
return nullptr;
}

Instruction *InstCombinerImpl::
canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
BinaryOperator &I) {
Expand Down Expand Up @@ -1521,6 +1539,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return R;

Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
if (Instruction *R = foldAddLikeCommutative(LHS, RHS, I.hasNoSignedWrap(),
I.hasNoUnsignedWrap()))
return R;
if (Instruction *R = foldAddLikeCommutative(RHS, LHS, I.hasNoSignedWrap(),
I.hasNoUnsignedWrap()))
return R;
Type *Ty = I.getType();
if (Ty->isIntOrIntVectorTy(1))
return BinaryOperator::CreateXor(LHS, RHS);
Expand Down Expand Up @@ -2286,6 +2310,33 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
}
}

{
Value *W, *Z;
if (match(Op0, m_AddLike(m_Value(W), m_Value(X))) &&
match(Op1, m_AddLike(m_Value(Y), m_Value(Z)))) {
Instruction *R = nullptr;
if (W == Y)
R = BinaryOperator::CreateSub(X, Z);
else if (W == Z)
R = BinaryOperator::CreateSub(X, Y);
else if (X == Y)
R = BinaryOperator::CreateSub(W, Z);
else if (X == Z)
R = BinaryOperator::CreateSub(W, Y);
if (R) {
bool NSW = I.hasNoSignedWrap() &&
match(Op0, m_NSWAddLike(m_Value(), m_Value())) &&
match(Op1, m_NSWAddLike(m_Value(), m_Value()));

bool NUW = I.hasNoUnsignedWrap() &&
match(Op1, m_NUWAddLike(m_Value(), m_Value()));
R->setHasNoSignedWrap(NSW);
R->setHasNoUnsignedWrap(NUW);
return R;
}
}
}

// (~X) - (~Y) --> Y - X
{
// Need to ensure we can consume at least one of the `not` instructions,
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3580,6 +3580,17 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
return R;

if (cast<PossiblyDisjointInst>(I).isDisjoint()) {
if (Instruction *R =
foldAddLikeCommutative(I.getOperand(0), I.getOperand(1),
/*NSW=*/true, /*NUW=*/true))
return R;
if (Instruction *R =
foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
/*NSW=*/true, /*NUW=*/true))
return R;
}

Value *X, *Y;
const APInt *CV;
if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) &&
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
FPClassTest DemandedMask, KnownFPClass &Known,
unsigned Depth = 0);

/// Common transforms for add / disjoint or
Instruction *foldAddLikeCommutative(Value *LHS, Value *RHS, bool NSW,
bool NUW);

/// Canonicalize the position of binops relative to shufflevector.
Instruction *foldVectorBinop(BinaryOperator &Inst);
Instruction *foldVectorSelect(SelectInst &Sel);
Expand Down
20 changes: 10 additions & 10 deletions llvm/test/Transforms/InstCombine/fold-add-sub.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ define i8 @test_add_nsw(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = add nsw i8 [[X]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
; CHECK-NEXT: [[R:%.*]] = sub nsw i8 [[LHS]], [[RHS]]
; CHECK-NEXT: [[R:%.*]] = sub nsw i8 [[Y]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = add nsw i8 %x, %y
Expand All @@ -25,7 +25,7 @@ define i8 @test_add_nsw_no_prop(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = add nuw i8 [[X]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
; CHECK-NEXT: [[R:%.*]] = sub nsw i8 [[LHS]], [[RHS]]
; CHECK-NEXT: [[R:%.*]] = sub i8 [[Y]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = add nsw i8 %x, %y
Expand All @@ -42,7 +42,7 @@ define i8 @test_add(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = add i8 [[X]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
; CHECK-NEXT: [[R:%.*]] = sub i8 [[LHS]], [[RHS]]
; CHECK-NEXT: [[R:%.*]] = sub i8 [[Y]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = add i8 %x, %y
Expand Down Expand Up @@ -76,7 +76,7 @@ define i8 @test_add_nuw(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = or disjoint i8 [[X]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
; CHECK-NEXT: [[R:%.*]] = sub nuw i8 [[LHS]], [[RHS]]
; CHECK-NEXT: [[R:%.*]] = sub nuw i8 [[Y]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = add i8 %x, %y
Expand All @@ -93,7 +93,7 @@ define i8 @test_add_nuw_no_prop(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = or disjoint i8 [[X]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
; CHECK-NEXT: [[R:%.*]] = sub i8 [[LHS]], [[RHS]]
; CHECK-NEXT: [[R:%.*]] = sub i8 [[Y]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = add i8 %x, %y
Expand All @@ -110,7 +110,7 @@ define i8 @test_sub_nuw(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = sub nuw i8 [[Y]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
; CHECK-NEXT: [[R:%.*]] = add i8 [[LHS]], [[RHS]]
; CHECK-NEXT: [[R:%.*]] = sub nuw i8 [[X]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = sub nuw i8 %x, %y
Expand All @@ -127,7 +127,7 @@ define i8 @test_sub_nuw_no_prop(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = sub i8 [[Y]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
; CHECK-NEXT: [[R:%.*]] = add nuw i8 [[LHS]], [[RHS]]
; CHECK-NEXT: [[R:%.*]] = sub i8 [[X]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = sub nuw i8 %x, %y
Expand All @@ -144,7 +144,7 @@ define i8 @test_sub_nsw(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = sub nsw i8 [[Y]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
; CHECK-NEXT: [[R:%.*]] = or disjoint i8 [[LHS]], [[RHS]]
; CHECK-NEXT: [[R:%.*]] = sub nsw i8 [[X]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = sub nsw i8 %x, %y
Expand All @@ -161,7 +161,7 @@ define i8 @test_sub_nsw_no_prop(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = sub nsw i8 [[Y]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
; CHECK-NEXT: [[R:%.*]] = or disjoint i8 [[LHS]], [[RHS]]
; CHECK-NEXT: [[R:%.*]] = sub i8 [[X]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = sub i8 %x, %y
Expand All @@ -178,7 +178,7 @@ define i8 @test_sub_none(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[RHS:%.*]] = sub i8 [[Y]], [[Z:%.*]]
; CHECK-NEXT: call void @use.i8(i8 [[LHS]])
; CHECK-NEXT: call void @use.i8(i8 [[RHS]])
; CHECK-NEXT: [[R:%.*]] = add i8 [[LHS]], [[RHS]]
; CHECK-NEXT: [[R:%.*]] = sub i8 [[X]], [[Z]]
; CHECK-NEXT: ret i8 [[R]]
;
%lhs = sub i8 %x, %y
Expand Down

0 comments on commit bac00b0

Please sign in to comment.