Skip to content

Commit

Permalink
Fold Select with Const Operand to BinOp
Browse files Browse the repository at this point in the history
  • Loading branch information
veera-sivarajan committed Nov 30, 2024
1 parent 77c2b79 commit f63b7be
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 131 deletions.
44 changes: 18 additions & 26 deletions clang/test/CodeGen/attr-counted-by.c
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@ void test1(struct annotated *p, int index, int val) {
// SANITIZE-WITH-ATTR: cont3:
// SANITIZE-WITH-ATTR-NEXT: [[ARRAY:%.*]] = getelementptr inbounds nuw i8, ptr [[P]], i64 12
// SANITIZE-WITH-ATTR-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds nuw [0 x i32], ptr [[ARRAY]], i64 0, i64 [[INDEX]]
// SANITIZE-WITH-ATTR-NEXT: [[TMP2:%.*]] = shl i32 [[DOT_COUNTED_BY_LOAD]], 2
// SANITIZE-WITH-ATTR-NEXT: [[DOTINV:%.*]] = icmp slt i32 [[DOT_COUNTED_BY_LOAD]], 0
// SANITIZE-WITH-ATTR-NEXT: [[CONV:%.*]] = select i1 [[DOTINV]], i32 0, i32 [[TMP2]]
// SANITIZE-WITH-ATTR-NEXT: [[TMP2:%.*]] = tail call i32 @llvm.smax.i32(i32 [[DOT_COUNTED_BY_LOAD]], i32 0)
// SANITIZE-WITH-ATTR-NEXT: [[CONV:%.*]] = shl i32 [[TMP2]], 2
// SANITIZE-WITH-ATTR-NEXT: store i32 [[CONV]], ptr [[ARRAYIDX]], align 4, !tbaa [[TBAA4]]
// SANITIZE-WITH-ATTR-NEXT: ret void
//
Expand All @@ -130,9 +129,8 @@ void test1(struct annotated *p, int index, int val) {
// NO-SANITIZE-WITH-ATTR-NEXT: entry:
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOT_COUNTED_BY_GEP:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 8
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOT_COUNTED_BY_LOAD:%.*]] = load i32, ptr [[DOT_COUNTED_BY_GEP]], align 4
// NO-SANITIZE-WITH-ATTR-NEXT: [[TMP0:%.*]] = shl i32 [[DOT_COUNTED_BY_LOAD]], 2
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOTINV:%.*]] = icmp slt i32 [[DOT_COUNTED_BY_LOAD]], 0
// NO-SANITIZE-WITH-ATTR-NEXT: [[CONV:%.*]] = select i1 [[DOTINV]], i32 0, i32 [[TMP0]]
// NO-SANITIZE-WITH-ATTR-NEXT: [[TMP0:%.*]] = tail call i32 @llvm.smax.i32(i32 [[DOT_COUNTED_BY_LOAD]], i32 0)
// NO-SANITIZE-WITH-ATTR-NEXT: [[CONV:%.*]] = shl i32 [[TMP0]], 2
// NO-SANITIZE-WITH-ATTR-NEXT: [[ARRAY:%.*]] = getelementptr inbounds nuw i8, ptr [[P]], i64 12
// NO-SANITIZE-WITH-ATTR-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds nuw [0 x i32], ptr [[ARRAY]], i64 0, i64 [[INDEX]]
// NO-SANITIZE-WITH-ATTR-NEXT: store i32 [[CONV]], ptr [[ARRAYIDX]], align 4, !tbaa [[TBAA2]]
Expand Down Expand Up @@ -539,10 +537,9 @@ size_t test5_bdos(struct anon_struct *p) {
// SANITIZE-WITH-ATTR: cont3:
// SANITIZE-WITH-ATTR-NEXT: [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[P]], i64 16
// SANITIZE-WITH-ATTR-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [0 x i32], ptr [[TMP1]], i64 0, i64 [[IDXPROM]]
// SANITIZE-WITH-ATTR-NEXT: [[DOTINV:%.*]] = icmp slt i64 [[DOT_COUNTED_BY_LOAD]], 0
// SANITIZE-WITH-ATTR-NEXT: [[DOT_COUNTED_BY_LOAD_TR:%.*]] = trunc i64 [[DOT_COUNTED_BY_LOAD]] to i32
// SANITIZE-WITH-ATTR-NEXT: [[TMP2:%.*]] = shl i32 [[DOT_COUNTED_BY_LOAD_TR]], 2
// SANITIZE-WITH-ATTR-NEXT: [[CONV:%.*]] = select i1 [[DOTINV]], i32 0, i32 [[TMP2]]
// SANITIZE-WITH-ATTR-NEXT: [[TMP2:%.*]] = tail call i64 @llvm.smax.i64(i64 [[DOT_COUNTED_BY_LOAD]], i64 0)
// SANITIZE-WITH-ATTR-NEXT: [[DOTTR:%.*]] = trunc i64 [[TMP2]] to i32
// SANITIZE-WITH-ATTR-NEXT: [[CONV:%.*]] = shl i32 [[DOTTR]], 2
// SANITIZE-WITH-ATTR-NEXT: store i32 [[CONV]], ptr [[ARRAYIDX]], align 4, !tbaa [[TBAA4]]
// SANITIZE-WITH-ATTR-NEXT: ret void
//
Expand All @@ -551,10 +548,9 @@ size_t test5_bdos(struct anon_struct *p) {
// NO-SANITIZE-WITH-ATTR-NEXT: entry:
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOT_COUNTED_BY_GEP:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 8
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOT_COUNTED_BY_LOAD:%.*]] = load i64, ptr [[DOT_COUNTED_BY_GEP]], align 4
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOTINV:%.*]] = icmp slt i64 [[DOT_COUNTED_BY_LOAD]], 0
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOT_COUNTED_BY_LOAD_TR:%.*]] = trunc i64 [[DOT_COUNTED_BY_LOAD]] to i32
// NO-SANITIZE-WITH-ATTR-NEXT: [[TMP0:%.*]] = shl i32 [[DOT_COUNTED_BY_LOAD_TR]], 2
// NO-SANITIZE-WITH-ATTR-NEXT: [[CONV:%.*]] = select i1 [[DOTINV]], i32 0, i32 [[TMP0]]
// NO-SANITIZE-WITH-ATTR-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.smax.i64(i64 [[DOT_COUNTED_BY_LOAD]], i64 0)
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOTTR:%.*]] = trunc i64 [[TMP0]] to i32
// NO-SANITIZE-WITH-ATTR-NEXT: [[CONV:%.*]] = shl i32 [[DOTTR]], 2
// NO-SANITIZE-WITH-ATTR-NEXT: [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[P]], i64 16
// NO-SANITIZE-WITH-ATTR-NEXT: [[IDXPROM:%.*]] = sext i32 [[INDEX]] to i64
// NO-SANITIZE-WITH-ATTR-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [0 x i32], ptr [[TMP1]], i64 0, i64 [[IDXPROM]]
Expand Down Expand Up @@ -588,19 +584,17 @@ void test6(struct anon_struct *p, int index) {
// SANITIZE-WITH-ATTR-NEXT: entry:
// SANITIZE-WITH-ATTR-NEXT: [[DOT_COUNTED_BY_GEP:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 8
// SANITIZE-WITH-ATTR-NEXT: [[DOT_COUNTED_BY_LOAD:%.*]] = load i64, ptr [[DOT_COUNTED_BY_GEP]], align 4
// SANITIZE-WITH-ATTR-NEXT: [[TMP0:%.*]] = shl nuw i64 [[DOT_COUNTED_BY_LOAD]], 2
// SANITIZE-WITH-ATTR-NEXT: [[DOTINV:%.*]] = icmp slt i64 [[DOT_COUNTED_BY_LOAD]], 0
// SANITIZE-WITH-ATTR-NEXT: [[TMP1:%.*]] = select i1 [[DOTINV]], i64 0, i64 [[TMP0]]
// SANITIZE-WITH-ATTR-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.smax.i64(i64 [[DOT_COUNTED_BY_LOAD]], i64 0)
// SANITIZE-WITH-ATTR-NEXT: [[TMP1:%.*]] = shl i64 [[TMP0]], 2
// SANITIZE-WITH-ATTR-NEXT: ret i64 [[TMP1]]
//
// NO-SANITIZE-WITH-ATTR-LABEL: define dso_local range(i64 0, -3) i64 @test6_bdos(
// NO-SANITIZE-WITH-ATTR-SAME: ptr nocapture noundef readonly [[P:%.*]]) local_unnamed_addr #[[ATTR2]] {
// NO-SANITIZE-WITH-ATTR-NEXT: entry:
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOT_COUNTED_BY_GEP:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 8
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOT_COUNTED_BY_LOAD:%.*]] = load i64, ptr [[DOT_COUNTED_BY_GEP]], align 4
// NO-SANITIZE-WITH-ATTR-NEXT: [[TMP0:%.*]] = shl nuw i64 [[DOT_COUNTED_BY_LOAD]], 2
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOTINV:%.*]] = icmp slt i64 [[DOT_COUNTED_BY_LOAD]], 0
// NO-SANITIZE-WITH-ATTR-NEXT: [[TMP1:%.*]] = select i1 [[DOTINV]], i64 0, i64 [[TMP0]]
// NO-SANITIZE-WITH-ATTR-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.smax.i64(i64 [[DOT_COUNTED_BY_LOAD]], i64 0)
// NO-SANITIZE-WITH-ATTR-NEXT: [[TMP1:%.*]] = shl i64 [[TMP0]], 2
// NO-SANITIZE-WITH-ATTR-NEXT: ret i64 [[TMP1]]
//
// SANITIZE-WITHOUT-ATTR-LABEL: define dso_local i64 @test6_bdos(
Expand Down Expand Up @@ -1740,9 +1734,8 @@ struct annotated_struct_array {
// SANITIZE-WITH-ATTR: cont20:
// SANITIZE-WITH-ATTR-NEXT: [[ARRAY:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP2]], i64 12
// SANITIZE-WITH-ATTR-NEXT: [[ARRAYIDX18:%.*]] = getelementptr inbounds [0 x i32], ptr [[ARRAY]], i64 0, i64 [[IDXPROM15]]
// SANITIZE-WITH-ATTR-NEXT: [[TMP5:%.*]] = shl i32 [[DOT_COUNTED_BY_LOAD]], 2
// SANITIZE-WITH-ATTR-NEXT: [[DOTINV:%.*]] = icmp slt i32 [[DOT_COUNTED_BY_LOAD]], 0
// SANITIZE-WITH-ATTR-NEXT: [[CONV:%.*]] = select i1 [[DOTINV]], i32 0, i32 [[TMP5]]
// SANITIZE-WITH-ATTR-NEXT: [[TMP5:%.*]] = tail call i32 @llvm.smax.i32(i32 [[DOT_COUNTED_BY_LOAD]], i32 0)
// SANITIZE-WITH-ATTR-NEXT: [[CONV:%.*]] = shl i32 [[TMP5]], 2
// SANITIZE-WITH-ATTR-NEXT: store i32 [[CONV]], ptr [[ARRAYIDX18]], align 4, !tbaa [[TBAA4]]
// SANITIZE-WITH-ATTR-NEXT: ret void
//
Expand All @@ -1754,9 +1747,8 @@ struct annotated_struct_array {
// NO-SANITIZE-WITH-ATTR-NEXT: [[TMP0:%.*]] = load ptr, ptr [[ARRAYIDX]], align 8, !tbaa [[TBAA11]]
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOT_COUNTED_BY_GEP:%.*]] = getelementptr inbounds i8, ptr [[TMP0]], i64 8
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOT_COUNTED_BY_LOAD:%.*]] = load i32, ptr [[DOT_COUNTED_BY_GEP]], align 4
// NO-SANITIZE-WITH-ATTR-NEXT: [[TMP1:%.*]] = shl i32 [[DOT_COUNTED_BY_LOAD]], 2
// NO-SANITIZE-WITH-ATTR-NEXT: [[DOTINV:%.*]] = icmp slt i32 [[DOT_COUNTED_BY_LOAD]], 0
// NO-SANITIZE-WITH-ATTR-NEXT: [[CONV:%.*]] = select i1 [[DOTINV]], i32 0, i32 [[TMP1]]
// NO-SANITIZE-WITH-ATTR-NEXT: [[TMP1:%.*]] = tail call i32 @llvm.smax.i32(i32 [[DOT_COUNTED_BY_LOAD]], i32 0)
// NO-SANITIZE-WITH-ATTR-NEXT: [[CONV:%.*]] = shl i32 [[TMP1]], 2
// NO-SANITIZE-WITH-ATTR-NEXT: [[ARRAY:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i64 12
// NO-SANITIZE-WITH-ATTR-NEXT: [[IDXPROM4:%.*]] = sext i32 [[IDX2]] to i64
// NO-SANITIZE-WITH-ATTR-NEXT: [[ARRAYIDX5:%.*]] = getelementptr inbounds [0 x i32], ptr [[ARRAY]], i64 0, i64 [[IDXPROM4]]
Expand Down
10 changes: 10 additions & 0 deletions llvm/include/llvm/Analysis/ValueTracking.h
Original file line number Diff line number Diff line change
Expand Up @@ -1178,10 +1178,20 @@ SelectPatternResult matchDecomposedSelectPattern(
CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS,
Instruction::CastOps *CastOp = nullptr, unsigned Depth = 0);

/// Determine the pattern for predicate `X Pred Y ? X : Y`.
SelectPatternResult
getSelectPattern(CmpInst::Predicate Pred,
SelectPatternNaNBehavior NaNBehavior = SPNB_NA,
bool Ordered = false);

/// Return the canonical comparison predicate for the specified
/// minimum/maximum flavor.
CmpInst::Predicate getMinMaxPred(SelectPatternFlavor SPF, bool Ordered = false);

/// Convert given `SPF` to equivalent min/max intrinsic.
/// Caller must ensure `SPF` is an integer min or max pattern.
Intrinsic::ID getMinMaxIntrinsic(SelectPatternFlavor SPF);

/// Return the inverse minimum/maximum flavor of the specified flavor.
/// For example, signed minimum is the inverse of signed maximum.
SelectPatternFlavor getInverseMinMaxFlavor(SelectPatternFlavor SPF);
Expand Down
69 changes: 48 additions & 21 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8589,6 +8589,37 @@ bool llvm::isKnownInversion(const Value *X, const Value *Y) {
return CR1.inverse() == CR2;
}

SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred,
SelectPatternNaNBehavior NaNBehavior,
bool Ordered) {
switch (Pred) {
default:
return {SPF_UNKNOWN, SPNB_NA, false}; // Equality.
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
return {SPF_UMAX, SPNB_NA, false};
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
return {SPF_SMAX, SPNB_NA, false};
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
return {SPF_UMIN, SPNB_NA, false};
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
return {SPF_SMIN, SPNB_NA, false};
case FCmpInst::FCMP_UGT:
case FCmpInst::FCMP_UGE:
case FCmpInst::FCMP_OGT:
case FCmpInst::FCMP_OGE:
return {SPF_FMAXNUM, NaNBehavior, Ordered};
case FCmpInst::FCMP_ULT:
case FCmpInst::FCMP_ULE:
case FCmpInst::FCMP_OLT:
case FCmpInst::FCMP_OLE:
return {SPF_FMINNUM, NaNBehavior, Ordered};
}
}

static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
FastMathFlags FMF,
Value *CmpLHS, Value *CmpRHS,
Expand Down Expand Up @@ -8696,27 +8727,8 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
}

// ([if]cmp X, Y) ? X : Y
if (TrueVal == CmpLHS && FalseVal == CmpRHS) {
switch (Pred) {
default: return {SPF_UNKNOWN, SPNB_NA, false}; // Equality.
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE: return {SPF_UMAX, SPNB_NA, false};
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE: return {SPF_SMAX, SPNB_NA, false};
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE: return {SPF_UMIN, SPNB_NA, false};
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE: return {SPF_SMIN, SPNB_NA, false};
case FCmpInst::FCMP_UGT:
case FCmpInst::FCMP_UGE:
case FCmpInst::FCMP_OGT:
case FCmpInst::FCMP_OGE: return {SPF_FMAXNUM, NaNBehavior, Ordered};
case FCmpInst::FCMP_ULT:
case FCmpInst::FCMP_ULE:
case FCmpInst::FCMP_OLT:
case FCmpInst::FCMP_OLE: return {SPF_FMINNUM, NaNBehavior, Ordered};
}
}
if (TrueVal == CmpLHS && FalseVal == CmpRHS)
return getSelectPattern(Pred, NaNBehavior, Ordered);

if (isKnownNegation(TrueVal, FalseVal)) {
// Sign-extending LHS does not change its sign, so TrueVal/FalseVal can
Expand Down Expand Up @@ -8960,6 +8972,21 @@ CmpInst::Predicate llvm::getMinMaxPred(SelectPatternFlavor SPF, bool Ordered) {
llvm_unreachable("unhandled!");
}

Intrinsic::ID llvm::getMinMaxIntrinsic(SelectPatternFlavor SPF) {
switch (SPF) {
case SelectPatternFlavor::SPF_UMIN:
return Intrinsic::umin;
case SelectPatternFlavor::SPF_UMAX:
return Intrinsic::umax;
case SelectPatternFlavor::SPF_SMIN:
return Intrinsic::smin;
case SelectPatternFlavor::SPF_SMAX:
return Intrinsic::smax;
default:
llvm_unreachable("Unexpected SPF");
}
}

SelectPatternFlavor llvm::getInverseMinMaxFlavor(SelectPatternFlavor SPF) {
if (SPF == SPF_SMIN) return SPF_SMAX;
if (SPF == SPF_UMIN) return SPF_UMAX;
Expand Down
75 changes: 58 additions & 17 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1257,23 +1257,7 @@ static Value *canonicalizeSPF(ICmpInst &Cmp, Value *TrueVal, Value *FalseVal,
}

if (SelectPatternResult::isMinOrMax(SPF)) {
Intrinsic::ID IntrinsicID;
switch (SPF) {
case SelectPatternFlavor::SPF_UMIN:
IntrinsicID = Intrinsic::umin;
break;
case SelectPatternFlavor::SPF_UMAX:
IntrinsicID = Intrinsic::umax;
break;
case SelectPatternFlavor::SPF_SMIN:
IntrinsicID = Intrinsic::smin;
break;
case SelectPatternFlavor::SPF_SMAX:
IntrinsicID = Intrinsic::smax;
break;
default:
llvm_unreachable("Unexpected SPF");
}
Intrinsic::ID IntrinsicID = getMinMaxIntrinsic(SPF);
return IC.Builder.CreateBinaryIntrinsic(IntrinsicID, LHS, RHS);
}

Expand Down Expand Up @@ -1898,6 +1882,60 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
return nullptr;
}

/// Fold `X Pred C1 ? X BOp C2 : C1 BOp C2` to `min/max(X, C1) BOp C2`.
/// This allows for better canonicalization.
static Value *foldSelectWithConstOpToBinOp(ICmpInst *Cmp, Value *TrueVal,
Value *FalseVal,
IRBuilderBase &Builder) {
BinaryOperator *BOp;
Constant *C1, *C2, *C3;
Value *X;
ICmpInst::Predicate Predicate;

if (!match(Cmp, m_ICmp(Predicate, m_Value(X), m_Constant(C1))))
return nullptr;

if (!ICmpInst::isRelational(Predicate))
return nullptr;

if (match(TrueVal, m_Constant())) {
std::swap(FalseVal, TrueVal);
Predicate = ICmpInst::getInversePredicate(Predicate);
}

if (!match(TrueVal, m_BinOp(BOp)) || !match(FalseVal, m_Constant(C3)))
return nullptr;

unsigned Opcode = BOp->getOpcode();

if (Instruction::isIntDivRem(Opcode))
return nullptr;

if (!match(BOp, m_OneUse(m_BinOp(m_Specific(X), m_Constant(C2)))))
return nullptr;

Value *RHS;
SelectPatternFlavor SPF;
const DataLayout &Layout = BOp->getDataLayout();
auto Flipped =
InstCombiner::getFlippedStrictnessPredicateAndConstant(Predicate, C1);

if (C3 == ConstantFoldBinaryOpOperands(Opcode, C1, C2, Layout)) {
SPF = getSelectPattern(Predicate).Flavor;
RHS = C1;
} else if (Flipped && C3 == ConstantFoldBinaryOpOperands(
Opcode, Flipped->second, C2, Layout)) {
SPF = getSelectPattern(Flipped->first).Flavor;
RHS = Flipped->second;
} else {
return nullptr;
}

Intrinsic::ID IntrinsicID = getMinMaxIntrinsic(SPF);
Value *Intrinsic = Builder.CreateBinaryIntrinsic(IntrinsicID, X, RHS);
return Builder.CreateBinOp(BOp->getOpcode(), Intrinsic, C2);
}

/// Visit a SelectInst that has an ICmpInst as its first operand.
Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
ICmpInst *ICI) {
Expand Down Expand Up @@ -1987,6 +2025,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);

if (Value *V = foldSelectWithConstOpToBinOp(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);

return Changed ? &SI : nullptr;
}

Expand Down
Loading

0 comments on commit f63b7be

Please sign in to comment.