Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IR: introduce struct with CmpInst::Predicate and samesign #116867

Merged
merged 2 commits into from
Dec 3, 2024

Conversation

artagnon
Copy link
Contributor

@artagnon artagnon commented Nov 19, 2024

Introduce llvm::CmpPredicate, an abstraction over a floating-point predicate, and a pack of an integer predicate with samesign information, in order to ease extending large portions of the codebase that take a CmpInst::Predicate to respect the samesign flag.

We have chosen to demonstrate the utility of this new abstraction by migrating parts of ValueTracking, InstructionSimplify, and InstCombine from CmpInst::Predicate to llvm::CmpPredicate. There should be no functional changes, as we don't perform any extra optimizations with samesign in this patch, or use CmpPredicate::getMatching.

The design approach taken by this patch allows for unaudited callers of APIs that take a llvm::CmpPredicate to silently drop the samesign information; it does not pose a correctness issue, and allows us to migrate the codebase piece-wise.

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-llvm-ir

Author: Ramkumar Ramachandra (artagnon)

Changes

Introduce CmpInst::PredicateSign, an abstraction over a floating-point predicate, and a pack of an integer predicate with samesign information, in order to ease extending large portions of the codebase that take a CmpInst::Predicate to respect the samesign flag.

We have chosen to demonstrate the utility of this new abstraction by migrating ValueTracking, InstructionSimplify, and InstCombine from CmpInst::Predicate to CmpInst::PredicateSign. There should be no functional changes, as we don't perform any extra optimizations with samesign in this patch.

The design approach taken by this patch allows for unaudited callers of APIs that take a CmpInst::PredicateSign to silently drop the samesign information; it does not pose a correctness issue, and allows us to migrate the codebase piece-wise.

-- 8< --
Based on #116866.


Patch is 42.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116867.diff

11 Files Affected:

  • (modified) llvm/include/llvm/Analysis/InstructionSimplify.h (+4-4)
  • (modified) llvm/include/llvm/Analysis/ValueTracking.h (+2-2)
  • (modified) llvm/include/llvm/IR/InstrTypes.h (+42-16)
  • (modified) llvm/include/llvm/IR/Instructions.h (+7-19)
  • (modified) llvm/include/llvm/Transforms/InstCombine/InstCombiner.h (+7-6)
  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+44-41)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+9-9)
  • (modified) llvm/lib/IR/Instructions.cpp (+30-46)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+18-17)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+5-5)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+3-3)
diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index cf7d3e044188a6..803050c7a0f438 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -152,12 +152,12 @@ Value *simplifyOrInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
 Value *simplifyXorInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
 
 /// Given operands for an ICmpInst, fold the result or return null.
-Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+Value *simplifyICmpInst(CmpInst::PredicateSign Pred, Value *LHS, Value *RHS,
                         const SimplifyQuery &Q);
 
 /// Given operands for an FCmpInst, fold the result or return null.
-Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
-                        FastMathFlags FMF, const SimplifyQuery &Q);
+Value *simplifyFCmpInst(CmpInst::PredicateSign Predicate, Value *LHS,
+                        Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q);
 
 /// Given operands for a SelectInst, fold the result or return null.
 Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
@@ -200,7 +200,7 @@ Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1, ArrayRef<int> Mask,
 //=== Helper functions for higher up the class hierarchy.
 
 /// Given operands for a CmpInst, fold the result or return null.
-Value *simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+Value *simplifyCmpInst(CmpInst::PredicateSign Predicate, Value *LHS, Value *RHS,
                        const SimplifyQuery &Q);
 
 /// Given operand for a UnaryOperator, fold the result or return null.
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 2b0377903ac8e3..81982b0a0a79d8 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -1246,7 +1246,7 @@ std::optional<bool> isImpliedCondition(const Value *LHS, const Value *RHS,
                                        bool LHSIsTrue = true,
                                        unsigned Depth = 0);
 std::optional<bool> isImpliedCondition(const Value *LHS,
-                                       CmpInst::Predicate RHSPred,
+                                       CmpInst::PredicateSign RHSPred,
                                        const Value *RHSOp0, const Value *RHSOp1,
                                        const DataLayout &DL,
                                        bool LHSIsTrue = true,
@@ -1257,7 +1257,7 @@ std::optional<bool> isImpliedCondition(const Value *LHS,
 std::optional<bool> isImpliedByDomCondition(const Value *Cond,
                                             const Instruction *ContextI,
                                             const DataLayout &DL);
-std::optional<bool> isImpliedByDomCondition(CmpInst::Predicate Pred,
+std::optional<bool> isImpliedByDomCondition(CmpInst::PredicateSign Pred,
                                             const Value *LHS, const Value *RHS,
                                             const Instruction *ContextI,
                                             const DataLayout &DL);
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 1c60eae7f2f85b..ebba33d1a8f8ed 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -722,6 +722,31 @@ class CmpInst : public Instruction {
                               force_iteration_on_noniterable_enum);
   }
 
+  /// An abstraction over a floating-point predicate, and a pack of an integer
+  /// predicate with samesign information. The getPredicateSign() family of
+  /// functions in ICmpInst construct and return this type. It is also implictly
+  /// constructed with a Predicate, dropping samesign information.
+  class PredicateSign {
+    Predicate Pred;
+    std::optional<bool> HasSameSign;
+
+  public:
+    PredicateSign(Predicate Pred, bool HasSameSign)
+        : Pred(Pred), HasSameSign(HasSameSign) {}
+
+    PredicateSign(Predicate Pred) : Pred(Pred) {
+      if (isIntPredicate(Pred))
+        HasSameSign = false;
+    }
+
+    operator Predicate() { return Pred; }
+
+    bool hasSameSign() {
+      assert(isIntPredicate(Pred) && HasSameSign);
+      return *HasSameSign;
+    }
+  };
+
 protected:
   CmpInst(Type *ty, Instruction::OtherOps op, Predicate pred, Value *LHS,
           Value *RHS, const Twine &Name = "",
@@ -935,28 +960,29 @@ class CmpInst : public Instruction {
     return isUnsigned(getPredicate());
   }
 
-  /// For example, ULT->SLT, ULE->SLE, UGT->SGT, UGE->SGE, SLT->Failed assert
-  /// @returns the signed version of the unsigned predicate pred.
-  /// return the signed version of a predicate
+  /// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
+  /// @returns the predicate that would be the result if the operand were
+  /// regarded as signed. Asserts on FP predicates.
+  /// Static variant.
   static Predicate getSignedPredicate(Predicate pred);
 
-  /// For example, ULT->SLT, ULE->SLE, UGT->SGT, UGE->SGE, SLT->Failed assert
-  /// @returns the signed version of the predicate for this instruction (which
-  /// has to be an unsigned predicate).
-  /// return the signed version of a predicate
-  Predicate getSignedPredicate() {
+  /// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
+  /// @returns the predicate that would be the result if the operand were
+  /// regarded as signed. Asserts on FP predicates.
+  Predicate getSignedPredicate() const {
     return getSignedPredicate(getPredicate());
   }
 
-  /// For example, SLT->ULT, SLE->ULE, SGT->UGT, SGE->UGE, ULT->Failed assert
-  /// @returns the unsigned version of the signed predicate pred.
+  /// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc.
+  /// @returns the predicate that would be the result if the operand were
+  /// regarded as unsigned. Asserts on FP predicates.
+  /// Static variant.
   static Predicate getUnsignedPredicate(Predicate pred);
 
-  /// For example, SLT->ULT, SLE->ULE, SGT->UGT, SGE->UGE, ULT->Failed assert
-  /// @returns the unsigned version of the predicate for this instruction (which
-  /// has to be an signed predicate).
-  /// return the unsigned version of a predicate
-  Predicate getUnsignedPredicate() {
+  /// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc.
+  /// @returns the predicate that would be the result if the operand were
+  /// regarded as unsigned. Asserts on FP predicates.
+  Predicate getUnsignedPredicate() const {
     return getUnsignedPredicate(getPredicate());
   }
 
@@ -968,7 +994,7 @@ class CmpInst : public Instruction {
   /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert
   /// @returns the unsigned version of the signed predicate pred or
   ///          the signed version of the signed predicate pred.
-  Predicate getFlippedSignednessPredicate() {
+  Predicate getFlippedSignednessPredicate() const {
     return getFlippedSignednessPredicate(getPredicate());
   }
 
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index 8eea659a00caf3..8b1d9a1aa17d82 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1203,29 +1203,17 @@ class ICmpInst: public CmpInst {
 #endif
   }
 
-  /// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
-  /// @returns the predicate that would be the result if the operand were
-  /// regarded as signed.
-  /// Return the signed version of the predicate
-  Predicate getSignedPredicate() const {
-    return getSignedPredicate(getPredicate());
+  PredicateSign getPredicateSign() const {
+    return {getPredicate(), hasSameSign()};
   }
 
-  /// This is a static version that you can use without an instruction.
-  /// Return the signed version of the predicate.
-  static Predicate getSignedPredicate(Predicate pred);
-
-  /// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc.
-  /// @returns the predicate that would be the result if the operand were
-  /// regarded as unsigned.
-  /// Return the unsigned version of the predicate
-  Predicate getUnsignedPredicate() const {
-    return getUnsignedPredicate(getPredicate());
+  PredicateSign getInversePredicateSign() const {
+    return {getInversePredicate(), hasSameSign()};
   }
 
-  /// This is a static version that you can use without an instruction.
-  /// Return the unsigned version of the predicate.
-  static Predicate getUnsignedPredicate(Predicate pred);
+  PredicateSign getSwappedPredicateSign() const {
+    return {getSwappedPredicate(), hasSameSign()};
+  }
 
   void setSameSign(bool B = true) {
     SubclassOptionalData = (SubclassOptionalData & ~SameSign) | (B * SameSign);
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 3075b7ebae59e6..850cf431d6e2e5 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -157,7 +157,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   /// conditional branch or select to create a compare with a canonical
   /// (inverted) predicate which is then more likely to be matched with other
   /// values.
-  static bool isCanonicalPredicate(CmpInst::Predicate Pred) {
+  static bool isCanonicalPredicate(CmpInst::PredicateSign Pred) {
     switch (Pred) {
     case CmpInst::ICMP_NE:
     case CmpInst::ICMP_ULE:
@@ -185,11 +185,12 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   }
 
   std::optional<std::pair<
-      CmpInst::Predicate,
-      Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpInst::
-                                                                       Predicate
-                                                                           Pred,
-                                                                   Constant *C);
+      CmpInst::PredicateSign,
+      Constant
+          *>> static getFlippedStrictnessPredicateAndConstant(CmpInst::
+                                                                  PredicateSign
+                                                                      Pred,
+                                                              Constant *C);
 
   static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
     // a ? b : false and a ? true : b are the canonical form of logical and/or.
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 93b601b22c3a39..a375e0202676db 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -63,10 +63,11 @@ static Value *simplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &,
                             unsigned);
 static Value *simplifyBinOp(unsigned, Value *, Value *, const FastMathFlags &,
                             const SimplifyQuery &, unsigned);
-static Value *simplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &,
-                              unsigned);
-static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
-                               const SimplifyQuery &Q, unsigned MaxRecurse);
+static Value *simplifyCmpInst(CmpInst::PredicateSign, Value *, Value *,
+                              const SimplifyQuery &, unsigned);
+static Value *simplifyICmpInst(CmpInst::PredicateSign Predicate, Value *LHS,
+                               Value *RHS, const SimplifyQuery &Q,
+                               unsigned MaxRecurse);
 static Value *simplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned);
 static Value *simplifyXorInst(Value *, Value *, const SimplifyQuery &,
                               unsigned);
@@ -132,7 +133,7 @@ static Constant *getFalse(Type *Ty) { return ConstantInt::getFalse(Ty); }
 static Constant *getTrue(Type *Ty) { return ConstantInt::getTrue(Ty); }
 
 /// isSameCompare - Is V equivalent to the comparison "LHS Pred RHS"?
-static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS,
+static bool isSameCompare(Value *V, CmpInst::PredicateSign Pred, Value *LHS,
                           Value *RHS) {
   CmpInst *Cmp = dyn_cast<CmpInst>(V);
   if (!Cmp)
@@ -150,7 +151,7 @@ static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS,
 ///  %cmp = icmp sle i32 %sel, %rhs
 /// Compose new comparison by substituting %sel with either %tv or %fv
 /// and see if it simplifies.
-static Value *simplifyCmpSelCase(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyCmpSelCase(CmpInst::PredicateSign Pred, Value *LHS,
                                  Value *RHS, Value *Cond,
                                  const SimplifyQuery &Q, unsigned MaxRecurse,
                                  Constant *TrueOrFalse) {
@@ -167,7 +168,7 @@ static Value *simplifyCmpSelCase(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Simplify comparison with true branch of select
-static Value *simplifyCmpSelTrueCase(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyCmpSelTrueCase(CmpInst::PredicateSign Pred, Value *LHS,
                                      Value *RHS, Value *Cond,
                                      const SimplifyQuery &Q,
                                      unsigned MaxRecurse) {
@@ -176,7 +177,7 @@ static Value *simplifyCmpSelTrueCase(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Simplify comparison with false branch of select
-static Value *simplifyCmpSelFalseCase(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyCmpSelFalseCase(CmpInst::PredicateSign Pred, Value *LHS,
                                       Value *RHS, Value *Cond,
                                       const SimplifyQuery &Q,
                                       unsigned MaxRecurse) {
@@ -471,7 +472,7 @@ static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
 /// We can simplify %cmp1 to true, because both branches of select are
 /// less than 3. We compose new comparison by substituting %tmp with both
 /// branches of select and see if it can be simplified.
-static Value *threadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS,
+static Value *threadCmpOverSelect(CmpInst::PredicateSign Pred, Value *LHS,
                                   Value *RHS, const SimplifyQuery &Q,
                                   unsigned MaxRecurse) {
   // Recursion is always used, so bail out at once if we already hit the limit.
@@ -564,8 +565,9 @@ static Value *threadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS,
 /// comparison by seeing whether comparing with all of the incoming phi values
 /// yields the same result every time. If so returns the common result,
 /// otherwise returns null.
-static Value *threadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
-                               const SimplifyQuery &Q, unsigned MaxRecurse) {
+static Value *threadCmpOverPHI(CmpInst::PredicateSign Pred, Value *LHS,
+                               Value *RHS, const SimplifyQuery &Q,
+                               unsigned MaxRecurse) {
   // Recursion is always used, so bail out at once if we already hit the limit.
   if (!MaxRecurse--)
     return nullptr;
@@ -1001,7 +1003,7 @@ Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
 /// Given a predicate and two operands, return true if the comparison is true.
 /// This is a helper for div/rem simplification where we return some other value
 /// when we can prove a relationship between the operands.
-static bool isICmpTrue(ICmpInst::Predicate Pred, Value *LHS, Value *RHS,
+static bool isICmpTrue(ICmpInst::PredicateSign Pred, Value *LHS, Value *RHS,
                        const SimplifyQuery &Q, unsigned MaxRecurse) {
   Value *V = simplifyICmpInst(Pred, LHS, RHS, Q, MaxRecurse);
   Constant *C = dyn_cast_or_null<Constant>(V);
@@ -2597,7 +2599,7 @@ static Type *getCompareTy(Value *Op) {
 /// Rummage around inside V looking for something equivalent to the comparison
 /// "LHS Pred RHS". Return such a value if found, otherwise return null.
 /// Helper function for analyzing max/min idioms.
-static Value *extractEquivalentCondition(Value *V, CmpInst::Predicate Pred,
+static Value *extractEquivalentCondition(Value *V, CmpInst::PredicateSign Pred,
                                          Value *LHS, Value *RHS) {
   SelectInst *SI = dyn_cast<SelectInst>(V);
   if (!SI)
@@ -2706,7 +2708,7 @@ static bool haveNonOverlappingStorage(const Value *V1, const Value *V2) {
 // If the C and C++ standards are ever made sufficiently restrictive in this
 // area, it may be possible to update LLVM's semantics accordingly and reinstate
 // this optimization.
-static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS,
+static Constant *computePointerICmp(CmpInst::PredicateSign Pred, Value *LHS,
                                     Value *RHS, const SimplifyQuery &Q) {
   assert(LHS->getType() == RHS->getType() && "Must have same types");
   const DataLayout &DL = Q.DL;
@@ -2855,7 +2857,7 @@ static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Fold an icmp when its operands have i1 scalar type.
-static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpOfBools(CmpInst::PredicateSign Pred, Value *LHS,
                                   Value *RHS, const SimplifyQuery &Q) {
   Type *ITy = getCompareTy(LHS); // The return type.
   Type *OpTy = LHS->getType();   // The operand type.
@@ -2958,7 +2960,7 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Try hard to fold icmp with zero RHS because this is a common case.
-static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithZero(CmpInst::PredicateSign Pred, Value *LHS,
                                    Value *RHS, const SimplifyQuery &Q) {
   if (!match(RHS, m_Zero()))
     return nullptr;
@@ -3018,7 +3020,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
   return nullptr;
 }
 
-static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithConstant(CmpInst::PredicateSign Pred, Value *LHS,
                                        Value *RHS, const InstrInfoQuery &IIQ) {
   Type *ITy = getCompareTy(RHS); // The return type.
 
@@ -3066,7 +3068,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   return nullptr;
 }
 
-static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
+static Value *simplifyICmpWithBinOpOnLHS(CmpInst::PredicateSign Pred,
                                          BinaryOperator *LBO, Value *RHS,
                                          const SimplifyQuery &Q,
                                          unsigned MaxRecurse) {
@@ -3223,7 +3225,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
 // *) C1 < C2 && C1 >= 0, or
 // *) C2 < C1 && C1 <= 0.
 //
-static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS,
+static bool trySimplifyICmpWithAdds(CmpInst::PredicateSign Pred, Value *LHS,
                                     Value *RHS, const InstrInfoQuery &IIQ) {
   // TODO: only support icmp slt for now.
   if (Pred != CmpInst::ICMP_SLT || !IIQ.UseInstrInfo)
@@ -3248,7 +3250,7 @@ static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS,
 /// TODO: A large part of this logic is duplicated in InstCombine's
 /// foldICmpBinOp(). We should be able to share that and avoid the code
 /// duplication.
-static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithBinOp(CmpInst::PredicateSign Pred, Value *LHS,
                                     Value *RHS, const SimplifyQuery &Q,
                                     unsigned MaxRecurse) {
   BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS);
@@ -3482,7 +3484,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
 
 /// simplify integer comparisons where at least one operand of the compare
 /// matches an integer min/max idiom.
-static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithMinMax(CmpInst::PredicateSign Pred, Value *LHS,
                                      Value *RHS, const SimplifyQuery &Q,
                                      unsigned MaxRecurse) {
   Type *ITy = getCompareTy(LHS); // The return type.
@@ -...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Ramkumar Ramachandra (artagnon)

Changes

Introduce CmpInst::PredicateSign, an abstraction over a floating-point predicate, and a pack of an integer predicate with samesign information, in order to ease extending large portions of the codebase that take a CmpInst::Predicate to respect the samesign flag.

We have chosen to demonstrate the utility of this new abstraction by migrating ValueTracking, InstructionSimplify, and InstCombine from CmpInst::Predicate to CmpInst::PredicateSign. There should be no functional changes, as we don't perform any extra optimizations with samesign in this patch.

The design approach taken by this patch allows for unaudited callers of APIs that take a CmpInst::PredicateSign to silently drop the samesign information; it does not pose a correctness issue, and allows us to migrate the codebase piece-wise.

-- 8< --
Based on #116866.


Patch is 42.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116867.diff

11 Files Affected:

  • (modified) llvm/include/llvm/Analysis/InstructionSimplify.h (+4-4)
  • (modified) llvm/include/llvm/Analysis/ValueTracking.h (+2-2)
  • (modified) llvm/include/llvm/IR/InstrTypes.h (+42-16)
  • (modified) llvm/include/llvm/IR/Instructions.h (+7-19)
  • (modified) llvm/include/llvm/Transforms/InstCombine/InstCombiner.h (+7-6)
  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+44-41)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+9-9)
  • (modified) llvm/lib/IR/Instructions.cpp (+30-46)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+18-17)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+5-5)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+3-3)
diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index cf7d3e044188a6..803050c7a0f438 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -152,12 +152,12 @@ Value *simplifyOrInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
 Value *simplifyXorInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
 
 /// Given operands for an ICmpInst, fold the result or return null.
-Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+Value *simplifyICmpInst(CmpInst::PredicateSign Pred, Value *LHS, Value *RHS,
                         const SimplifyQuery &Q);
 
 /// Given operands for an FCmpInst, fold the result or return null.
-Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
-                        FastMathFlags FMF, const SimplifyQuery &Q);
+Value *simplifyFCmpInst(CmpInst::PredicateSign Predicate, Value *LHS,
+                        Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q);
 
 /// Given operands for a SelectInst, fold the result or return null.
 Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
@@ -200,7 +200,7 @@ Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1, ArrayRef<int> Mask,
 //=== Helper functions for higher up the class hierarchy.
 
 /// Given operands for a CmpInst, fold the result or return null.
-Value *simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+Value *simplifyCmpInst(CmpInst::PredicateSign Predicate, Value *LHS, Value *RHS,
                        const SimplifyQuery &Q);
 
 /// Given operand for a UnaryOperator, fold the result or return null.
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 2b0377903ac8e3..81982b0a0a79d8 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -1246,7 +1246,7 @@ std::optional<bool> isImpliedCondition(const Value *LHS, const Value *RHS,
                                        bool LHSIsTrue = true,
                                        unsigned Depth = 0);
 std::optional<bool> isImpliedCondition(const Value *LHS,
-                                       CmpInst::Predicate RHSPred,
+                                       CmpInst::PredicateSign RHSPred,
                                        const Value *RHSOp0, const Value *RHSOp1,
                                        const DataLayout &DL,
                                        bool LHSIsTrue = true,
@@ -1257,7 +1257,7 @@ std::optional<bool> isImpliedCondition(const Value *LHS,
 std::optional<bool> isImpliedByDomCondition(const Value *Cond,
                                             const Instruction *ContextI,
                                             const DataLayout &DL);
-std::optional<bool> isImpliedByDomCondition(CmpInst::Predicate Pred,
+std::optional<bool> isImpliedByDomCondition(CmpInst::PredicateSign Pred,
                                             const Value *LHS, const Value *RHS,
                                             const Instruction *ContextI,
                                             const DataLayout &DL);
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 1c60eae7f2f85b..ebba33d1a8f8ed 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -722,6 +722,31 @@ class CmpInst : public Instruction {
                               force_iteration_on_noniterable_enum);
   }
 
+  /// An abstraction over a floating-point predicate, and a pack of an integer
+  /// predicate with samesign information. The getPredicateSign() family of
+  /// functions in ICmpInst construct and return this type. It is also implictly
+  /// constructed with a Predicate, dropping samesign information.
+  class PredicateSign {
+    Predicate Pred;
+    std::optional<bool> HasSameSign;
+
+  public:
+    PredicateSign(Predicate Pred, bool HasSameSign)
+        : Pred(Pred), HasSameSign(HasSameSign) {}
+
+    PredicateSign(Predicate Pred) : Pred(Pred) {
+      if (isIntPredicate(Pred))
+        HasSameSign = false;
+    }
+
+    operator Predicate() { return Pred; }
+
+    bool hasSameSign() {
+      assert(isIntPredicate(Pred) && HasSameSign);
+      return *HasSameSign;
+    }
+  };
+
 protected:
   CmpInst(Type *ty, Instruction::OtherOps op, Predicate pred, Value *LHS,
           Value *RHS, const Twine &Name = "",
@@ -935,28 +960,29 @@ class CmpInst : public Instruction {
     return isUnsigned(getPredicate());
   }
 
-  /// For example, ULT->SLT, ULE->SLE, UGT->SGT, UGE->SGE, SLT->Failed assert
-  /// @returns the signed version of the unsigned predicate pred.
-  /// return the signed version of a predicate
+  /// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
+  /// @returns the predicate that would be the result if the operand were
+  /// regarded as signed. Asserts on FP predicates.
+  /// Static variant.
   static Predicate getSignedPredicate(Predicate pred);
 
-  /// For example, ULT->SLT, ULE->SLE, UGT->SGT, UGE->SGE, SLT->Failed assert
-  /// @returns the signed version of the predicate for this instruction (which
-  /// has to be an unsigned predicate).
-  /// return the signed version of a predicate
-  Predicate getSignedPredicate() {
+  /// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
+  /// @returns the predicate that would be the result if the operand were
+  /// regarded as signed. Asserts on FP predicates.
+  Predicate getSignedPredicate() const {
     return getSignedPredicate(getPredicate());
   }
 
-  /// For example, SLT->ULT, SLE->ULE, SGT->UGT, SGE->UGE, ULT->Failed assert
-  /// @returns the unsigned version of the signed predicate pred.
+  /// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc.
+  /// @returns the predicate that would be the result if the operand were
+  /// regarded as unsigned. Asserts on FP predicates.
+  /// Static variant.
   static Predicate getUnsignedPredicate(Predicate pred);
 
-  /// For example, SLT->ULT, SLE->ULE, SGT->UGT, SGE->UGE, ULT->Failed assert
-  /// @returns the unsigned version of the predicate for this instruction (which
-  /// has to be an signed predicate).
-  /// return the unsigned version of a predicate
-  Predicate getUnsignedPredicate() {
+  /// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc.
+  /// @returns the predicate that would be the result if the operand were
+  /// regarded as unsigned. Asserts on FP predicates.
+  Predicate getUnsignedPredicate() const {
     return getUnsignedPredicate(getPredicate());
   }
 
@@ -968,7 +994,7 @@ class CmpInst : public Instruction {
   /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert
   /// @returns the unsigned version of the signed predicate pred or
   ///          the signed version of the signed predicate pred.
-  Predicate getFlippedSignednessPredicate() {
+  Predicate getFlippedSignednessPredicate() const {
     return getFlippedSignednessPredicate(getPredicate());
   }
 
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index 8eea659a00caf3..8b1d9a1aa17d82 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1203,29 +1203,17 @@ class ICmpInst: public CmpInst {
 #endif
   }
 
-  /// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
-  /// @returns the predicate that would be the result if the operand were
-  /// regarded as signed.
-  /// Return the signed version of the predicate
-  Predicate getSignedPredicate() const {
-    return getSignedPredicate(getPredicate());
+  PredicateSign getPredicateSign() const {
+    return {getPredicate(), hasSameSign()};
   }
 
-  /// This is a static version that you can use without an instruction.
-  /// Return the signed version of the predicate.
-  static Predicate getSignedPredicate(Predicate pred);
-
-  /// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc.
-  /// @returns the predicate that would be the result if the operand were
-  /// regarded as unsigned.
-  /// Return the unsigned version of the predicate
-  Predicate getUnsignedPredicate() const {
-    return getUnsignedPredicate(getPredicate());
+  PredicateSign getInversePredicateSign() const {
+    return {getInversePredicate(), hasSameSign()};
   }
 
-  /// This is a static version that you can use without an instruction.
-  /// Return the unsigned version of the predicate.
-  static Predicate getUnsignedPredicate(Predicate pred);
+  PredicateSign getSwappedPredicateSign() const {
+    return {getSwappedPredicate(), hasSameSign()};
+  }
 
   void setSameSign(bool B = true) {
     SubclassOptionalData = (SubclassOptionalData & ~SameSign) | (B * SameSign);
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 3075b7ebae59e6..850cf431d6e2e5 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -157,7 +157,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   /// conditional branch or select to create a compare with a canonical
   /// (inverted) predicate which is then more likely to be matched with other
   /// values.
-  static bool isCanonicalPredicate(CmpInst::Predicate Pred) {
+  static bool isCanonicalPredicate(CmpInst::PredicateSign Pred) {
     switch (Pred) {
     case CmpInst::ICMP_NE:
     case CmpInst::ICMP_ULE:
@@ -185,11 +185,12 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   }
 
   std::optional<std::pair<
-      CmpInst::Predicate,
-      Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpInst::
-                                                                       Predicate
-                                                                           Pred,
-                                                                   Constant *C);
+      CmpInst::PredicateSign,
+      Constant
+          *>> static getFlippedStrictnessPredicateAndConstant(CmpInst::
+                                                                  PredicateSign
+                                                                      Pred,
+                                                              Constant *C);
 
   static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
     // a ? b : false and a ? true : b are the canonical form of logical and/or.
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 93b601b22c3a39..a375e0202676db 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -63,10 +63,11 @@ static Value *simplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &,
                             unsigned);
 static Value *simplifyBinOp(unsigned, Value *, Value *, const FastMathFlags &,
                             const SimplifyQuery &, unsigned);
-static Value *simplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &,
-                              unsigned);
-static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
-                               const SimplifyQuery &Q, unsigned MaxRecurse);
+static Value *simplifyCmpInst(CmpInst::PredicateSign, Value *, Value *,
+                              const SimplifyQuery &, unsigned);
+static Value *simplifyICmpInst(CmpInst::PredicateSign Predicate, Value *LHS,
+                               Value *RHS, const SimplifyQuery &Q,
+                               unsigned MaxRecurse);
 static Value *simplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned);
 static Value *simplifyXorInst(Value *, Value *, const SimplifyQuery &,
                               unsigned);
@@ -132,7 +133,7 @@ static Constant *getFalse(Type *Ty) { return ConstantInt::getFalse(Ty); }
 static Constant *getTrue(Type *Ty) { return ConstantInt::getTrue(Ty); }
 
 /// isSameCompare - Is V equivalent to the comparison "LHS Pred RHS"?
-static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS,
+static bool isSameCompare(Value *V, CmpInst::PredicateSign Pred, Value *LHS,
                           Value *RHS) {
   CmpInst *Cmp = dyn_cast<CmpInst>(V);
   if (!Cmp)
@@ -150,7 +151,7 @@ static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS,
 ///  %cmp = icmp sle i32 %sel, %rhs
 /// Compose new comparison by substituting %sel with either %tv or %fv
 /// and see if it simplifies.
-static Value *simplifyCmpSelCase(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyCmpSelCase(CmpInst::PredicateSign Pred, Value *LHS,
                                  Value *RHS, Value *Cond,
                                  const SimplifyQuery &Q, unsigned MaxRecurse,
                                  Constant *TrueOrFalse) {
@@ -167,7 +168,7 @@ static Value *simplifyCmpSelCase(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Simplify comparison with true branch of select
-static Value *simplifyCmpSelTrueCase(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyCmpSelTrueCase(CmpInst::PredicateSign Pred, Value *LHS,
                                      Value *RHS, Value *Cond,
                                      const SimplifyQuery &Q,
                                      unsigned MaxRecurse) {
@@ -176,7 +177,7 @@ static Value *simplifyCmpSelTrueCase(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Simplify comparison with false branch of select
-static Value *simplifyCmpSelFalseCase(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyCmpSelFalseCase(CmpInst::PredicateSign Pred, Value *LHS,
                                       Value *RHS, Value *Cond,
                                       const SimplifyQuery &Q,
                                       unsigned MaxRecurse) {
@@ -471,7 +472,7 @@ static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
 /// We can simplify %cmp1 to true, because both branches of select are
 /// less than 3. We compose new comparison by substituting %tmp with both
 /// branches of select and see if it can be simplified.
-static Value *threadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS,
+static Value *threadCmpOverSelect(CmpInst::PredicateSign Pred, Value *LHS,
                                   Value *RHS, const SimplifyQuery &Q,
                                   unsigned MaxRecurse) {
   // Recursion is always used, so bail out at once if we already hit the limit.
@@ -564,8 +565,9 @@ static Value *threadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS,
 /// comparison by seeing whether comparing with all of the incoming phi values
 /// yields the same result every time. If so returns the common result,
 /// otherwise returns null.
-static Value *threadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
-                               const SimplifyQuery &Q, unsigned MaxRecurse) {
+static Value *threadCmpOverPHI(CmpInst::PredicateSign Pred, Value *LHS,
+                               Value *RHS, const SimplifyQuery &Q,
+                               unsigned MaxRecurse) {
   // Recursion is always used, so bail out at once if we already hit the limit.
   if (!MaxRecurse--)
     return nullptr;
@@ -1001,7 +1003,7 @@ Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
 /// Given a predicate and two operands, return true if the comparison is true.
 /// This is a helper for div/rem simplification where we return some other value
 /// when we can prove a relationship between the operands.
-static bool isICmpTrue(ICmpInst::Predicate Pred, Value *LHS, Value *RHS,
+static bool isICmpTrue(ICmpInst::PredicateSign Pred, Value *LHS, Value *RHS,
                        const SimplifyQuery &Q, unsigned MaxRecurse) {
   Value *V = simplifyICmpInst(Pred, LHS, RHS, Q, MaxRecurse);
   Constant *C = dyn_cast_or_null<Constant>(V);
@@ -2597,7 +2599,7 @@ static Type *getCompareTy(Value *Op) {
 /// Rummage around inside V looking for something equivalent to the comparison
 /// "LHS Pred RHS". Return such a value if found, otherwise return null.
 /// Helper function for analyzing max/min idioms.
-static Value *extractEquivalentCondition(Value *V, CmpInst::Predicate Pred,
+static Value *extractEquivalentCondition(Value *V, CmpInst::PredicateSign Pred,
                                          Value *LHS, Value *RHS) {
   SelectInst *SI = dyn_cast<SelectInst>(V);
   if (!SI)
@@ -2706,7 +2708,7 @@ static bool haveNonOverlappingStorage(const Value *V1, const Value *V2) {
 // If the C and C++ standards are ever made sufficiently restrictive in this
 // area, it may be possible to update LLVM's semantics accordingly and reinstate
 // this optimization.
-static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS,
+static Constant *computePointerICmp(CmpInst::PredicateSign Pred, Value *LHS,
                                     Value *RHS, const SimplifyQuery &Q) {
   assert(LHS->getType() == RHS->getType() && "Must have same types");
   const DataLayout &DL = Q.DL;
@@ -2855,7 +2857,7 @@ static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Fold an icmp when its operands have i1 scalar type.
-static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpOfBools(CmpInst::PredicateSign Pred, Value *LHS,
                                   Value *RHS, const SimplifyQuery &Q) {
   Type *ITy = getCompareTy(LHS); // The return type.
   Type *OpTy = LHS->getType();   // The operand type.
@@ -2958,7 +2960,7 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Try hard to fold icmp with zero RHS because this is a common case.
-static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithZero(CmpInst::PredicateSign Pred, Value *LHS,
                                    Value *RHS, const SimplifyQuery &Q) {
   if (!match(RHS, m_Zero()))
     return nullptr;
@@ -3018,7 +3020,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
   return nullptr;
 }
 
-static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithConstant(CmpInst::PredicateSign Pred, Value *LHS,
                                        Value *RHS, const InstrInfoQuery &IIQ) {
   Type *ITy = getCompareTy(RHS); // The return type.
 
@@ -3066,7 +3068,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   return nullptr;
 }
 
-static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
+static Value *simplifyICmpWithBinOpOnLHS(CmpInst::PredicateSign Pred,
                                          BinaryOperator *LBO, Value *RHS,
                                          const SimplifyQuery &Q,
                                          unsigned MaxRecurse) {
@@ -3223,7 +3225,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
 // *) C1 < C2 && C1 >= 0, or
 // *) C2 < C1 && C1 <= 0.
 //
-static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS,
+static bool trySimplifyICmpWithAdds(CmpInst::PredicateSign Pred, Value *LHS,
                                     Value *RHS, const InstrInfoQuery &IIQ) {
   // TODO: only support icmp slt for now.
   if (Pred != CmpInst::ICMP_SLT || !IIQ.UseInstrInfo)
@@ -3248,7 +3250,7 @@ static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS,
 /// TODO: A large part of this logic is duplicated in InstCombine's
 /// foldICmpBinOp(). We should be able to share that and avoid the code
 /// duplication.
-static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithBinOp(CmpInst::PredicateSign Pred, Value *LHS,
                                     Value *RHS, const SimplifyQuery &Q,
                                     unsigned MaxRecurse) {
   BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS);
@@ -3482,7 +3484,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
 
 /// simplify integer comparisons where at least one operand of the compare
 /// matches an integer min/max idiom.
-static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithMinMax(CmpInst::PredicateSign Pred, Value *LHS,
                                      Value *RHS, const SimplifyQuery &Q,
                                      unsigned MaxRecurse) {
   Type *ITy = getCompareTy(LHS); // The return type.
@@ -...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Ramkumar Ramachandra (artagnon)

Changes

Introduce CmpInst::PredicateSign, an abstraction over a floating-point predicate, and a pack of an integer predicate with samesign information, in order to ease extending large portions of the codebase that take a CmpInst::Predicate to respect the samesign flag.

We have chosen to demonstrate the utility of this new abstraction by migrating ValueTracking, InstructionSimplify, and InstCombine from CmpInst::Predicate to CmpInst::PredicateSign. There should be no functional changes, as we don't perform any extra optimizations with samesign in this patch.

The design approach taken by this patch allows for unaudited callers of APIs that take a CmpInst::PredicateSign to silently drop the samesign information; it does not pose a correctness issue, and allows us to migrate the codebase piece-wise.

-- 8< --
Based on #116866.


Patch is 42.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116867.diff

11 Files Affected:

  • (modified) llvm/include/llvm/Analysis/InstructionSimplify.h (+4-4)
  • (modified) llvm/include/llvm/Analysis/ValueTracking.h (+2-2)
  • (modified) llvm/include/llvm/IR/InstrTypes.h (+42-16)
  • (modified) llvm/include/llvm/IR/Instructions.h (+7-19)
  • (modified) llvm/include/llvm/Transforms/InstCombine/InstCombiner.h (+7-6)
  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+44-41)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+9-9)
  • (modified) llvm/lib/IR/Instructions.cpp (+30-46)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+18-17)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+5-5)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+3-3)
diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index cf7d3e044188a6..803050c7a0f438 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -152,12 +152,12 @@ Value *simplifyOrInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
 Value *simplifyXorInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
 
 /// Given operands for an ICmpInst, fold the result or return null.
-Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+Value *simplifyICmpInst(CmpInst::PredicateSign Pred, Value *LHS, Value *RHS,
                         const SimplifyQuery &Q);
 
 /// Given operands for an FCmpInst, fold the result or return null.
-Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
-                        FastMathFlags FMF, const SimplifyQuery &Q);
+Value *simplifyFCmpInst(CmpInst::PredicateSign Predicate, Value *LHS,
+                        Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q);
 
 /// Given operands for a SelectInst, fold the result or return null.
 Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
@@ -200,7 +200,7 @@ Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1, ArrayRef<int> Mask,
 //=== Helper functions for higher up the class hierarchy.
 
 /// Given operands for a CmpInst, fold the result or return null.
-Value *simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+Value *simplifyCmpInst(CmpInst::PredicateSign Predicate, Value *LHS, Value *RHS,
                        const SimplifyQuery &Q);
 
 /// Given operand for a UnaryOperator, fold the result or return null.
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 2b0377903ac8e3..81982b0a0a79d8 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -1246,7 +1246,7 @@ std::optional<bool> isImpliedCondition(const Value *LHS, const Value *RHS,
                                        bool LHSIsTrue = true,
                                        unsigned Depth = 0);
 std::optional<bool> isImpliedCondition(const Value *LHS,
-                                       CmpInst::Predicate RHSPred,
+                                       CmpInst::PredicateSign RHSPred,
                                        const Value *RHSOp0, const Value *RHSOp1,
                                        const DataLayout &DL,
                                        bool LHSIsTrue = true,
@@ -1257,7 +1257,7 @@ std::optional<bool> isImpliedCondition(const Value *LHS,
 std::optional<bool> isImpliedByDomCondition(const Value *Cond,
                                             const Instruction *ContextI,
                                             const DataLayout &DL);
-std::optional<bool> isImpliedByDomCondition(CmpInst::Predicate Pred,
+std::optional<bool> isImpliedByDomCondition(CmpInst::PredicateSign Pred,
                                             const Value *LHS, const Value *RHS,
                                             const Instruction *ContextI,
                                             const DataLayout &DL);
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 1c60eae7f2f85b..ebba33d1a8f8ed 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -722,6 +722,31 @@ class CmpInst : public Instruction {
                               force_iteration_on_noniterable_enum);
   }
 
+  /// An abstraction over a floating-point predicate, and a pack of an integer
+  /// predicate with samesign information. The getPredicateSign() family of
+  /// functions in ICmpInst construct and return this type. It is also implictly
+  /// constructed with a Predicate, dropping samesign information.
+  class PredicateSign {
+    Predicate Pred;
+    std::optional<bool> HasSameSign;
+
+  public:
+    PredicateSign(Predicate Pred, bool HasSameSign)
+        : Pred(Pred), HasSameSign(HasSameSign) {}
+
+    PredicateSign(Predicate Pred) : Pred(Pred) {
+      if (isIntPredicate(Pred))
+        HasSameSign = false;
+    }
+
+    operator Predicate() { return Pred; }
+
+    bool hasSameSign() {
+      assert(isIntPredicate(Pred) && HasSameSign);
+      return *HasSameSign;
+    }
+  };
+
 protected:
   CmpInst(Type *ty, Instruction::OtherOps op, Predicate pred, Value *LHS,
           Value *RHS, const Twine &Name = "",
@@ -935,28 +960,29 @@ class CmpInst : public Instruction {
     return isUnsigned(getPredicate());
   }
 
-  /// For example, ULT->SLT, ULE->SLE, UGT->SGT, UGE->SGE, SLT->Failed assert
-  /// @returns the signed version of the unsigned predicate pred.
-  /// return the signed version of a predicate
+  /// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
+  /// @returns the predicate that would be the result if the operand were
+  /// regarded as signed. Asserts on FP predicates.
+  /// Static variant.
   static Predicate getSignedPredicate(Predicate pred);
 
-  /// For example, ULT->SLT, ULE->SLE, UGT->SGT, UGE->SGE, SLT->Failed assert
-  /// @returns the signed version of the predicate for this instruction (which
-  /// has to be an unsigned predicate).
-  /// return the signed version of a predicate
-  Predicate getSignedPredicate() {
+  /// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
+  /// @returns the predicate that would be the result if the operand were
+  /// regarded as signed. Asserts on FP predicates.
+  Predicate getSignedPredicate() const {
     return getSignedPredicate(getPredicate());
   }
 
-  /// For example, SLT->ULT, SLE->ULE, SGT->UGT, SGE->UGE, ULT->Failed assert
-  /// @returns the unsigned version of the signed predicate pred.
+  /// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc.
+  /// @returns the predicate that would be the result if the operand were
+  /// regarded as unsigned. Asserts on FP predicates.
+  /// Static variant.
   static Predicate getUnsignedPredicate(Predicate pred);
 
-  /// For example, SLT->ULT, SLE->ULE, SGT->UGT, SGE->UGE, ULT->Failed assert
-  /// @returns the unsigned version of the predicate for this instruction (which
-  /// has to be an signed predicate).
-  /// return the unsigned version of a predicate
-  Predicate getUnsignedPredicate() {
+  /// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc.
+  /// @returns the predicate that would be the result if the operand were
+  /// regarded as unsigned. Asserts on FP predicates.
+  Predicate getUnsignedPredicate() const {
     return getUnsignedPredicate(getPredicate());
   }
 
@@ -968,7 +994,7 @@ class CmpInst : public Instruction {
   /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert
   /// @returns the unsigned version of the signed predicate pred or
   ///          the signed version of the signed predicate pred.
-  Predicate getFlippedSignednessPredicate() {
+  Predicate getFlippedSignednessPredicate() const {
     return getFlippedSignednessPredicate(getPredicate());
   }
 
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index 8eea659a00caf3..8b1d9a1aa17d82 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1203,29 +1203,17 @@ class ICmpInst: public CmpInst {
 #endif
   }
 
-  /// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
-  /// @returns the predicate that would be the result if the operand were
-  /// regarded as signed.
-  /// Return the signed version of the predicate
-  Predicate getSignedPredicate() const {
-    return getSignedPredicate(getPredicate());
+  PredicateSign getPredicateSign() const {
+    return {getPredicate(), hasSameSign()};
   }
 
-  /// This is a static version that you can use without an instruction.
-  /// Return the signed version of the predicate.
-  static Predicate getSignedPredicate(Predicate pred);
-
-  /// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc.
-  /// @returns the predicate that would be the result if the operand were
-  /// regarded as unsigned.
-  /// Return the unsigned version of the predicate
-  Predicate getUnsignedPredicate() const {
-    return getUnsignedPredicate(getPredicate());
+  PredicateSign getInversePredicateSign() const {
+    return {getInversePredicate(), hasSameSign()};
   }
 
-  /// This is a static version that you can use without an instruction.
-  /// Return the unsigned version of the predicate.
-  static Predicate getUnsignedPredicate(Predicate pred);
+  PredicateSign getSwappedPredicateSign() const {
+    return {getSwappedPredicate(), hasSameSign()};
+  }
 
   void setSameSign(bool B = true) {
     SubclassOptionalData = (SubclassOptionalData & ~SameSign) | (B * SameSign);
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 3075b7ebae59e6..850cf431d6e2e5 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -157,7 +157,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   /// conditional branch or select to create a compare with a canonical
   /// (inverted) predicate which is then more likely to be matched with other
   /// values.
-  static bool isCanonicalPredicate(CmpInst::Predicate Pred) {
+  static bool isCanonicalPredicate(CmpInst::PredicateSign Pred) {
     switch (Pred) {
     case CmpInst::ICMP_NE:
     case CmpInst::ICMP_ULE:
@@ -185,11 +185,12 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   }
 
   std::optional<std::pair<
-      CmpInst::Predicate,
-      Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpInst::
-                                                                       Predicate
-                                                                           Pred,
-                                                                   Constant *C);
+      CmpInst::PredicateSign,
+      Constant
+          *>> static getFlippedStrictnessPredicateAndConstant(CmpInst::
+                                                                  PredicateSign
+                                                                      Pred,
+                                                              Constant *C);
 
   static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
     // a ? b : false and a ? true : b are the canonical form of logical and/or.
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 93b601b22c3a39..a375e0202676db 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -63,10 +63,11 @@ static Value *simplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &,
                             unsigned);
 static Value *simplifyBinOp(unsigned, Value *, Value *, const FastMathFlags &,
                             const SimplifyQuery &, unsigned);
-static Value *simplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &,
-                              unsigned);
-static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
-                               const SimplifyQuery &Q, unsigned MaxRecurse);
+static Value *simplifyCmpInst(CmpInst::PredicateSign, Value *, Value *,
+                              const SimplifyQuery &, unsigned);
+static Value *simplifyICmpInst(CmpInst::PredicateSign Predicate, Value *LHS,
+                               Value *RHS, const SimplifyQuery &Q,
+                               unsigned MaxRecurse);
 static Value *simplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned);
 static Value *simplifyXorInst(Value *, Value *, const SimplifyQuery &,
                               unsigned);
@@ -132,7 +133,7 @@ static Constant *getFalse(Type *Ty) { return ConstantInt::getFalse(Ty); }
 static Constant *getTrue(Type *Ty) { return ConstantInt::getTrue(Ty); }
 
 /// isSameCompare - Is V equivalent to the comparison "LHS Pred RHS"?
-static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS,
+static bool isSameCompare(Value *V, CmpInst::PredicateSign Pred, Value *LHS,
                           Value *RHS) {
   CmpInst *Cmp = dyn_cast<CmpInst>(V);
   if (!Cmp)
@@ -150,7 +151,7 @@ static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS,
 ///  %cmp = icmp sle i32 %sel, %rhs
 /// Compose new comparison by substituting %sel with either %tv or %fv
 /// and see if it simplifies.
-static Value *simplifyCmpSelCase(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyCmpSelCase(CmpInst::PredicateSign Pred, Value *LHS,
                                  Value *RHS, Value *Cond,
                                  const SimplifyQuery &Q, unsigned MaxRecurse,
                                  Constant *TrueOrFalse) {
@@ -167,7 +168,7 @@ static Value *simplifyCmpSelCase(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Simplify comparison with true branch of select
-static Value *simplifyCmpSelTrueCase(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyCmpSelTrueCase(CmpInst::PredicateSign Pred, Value *LHS,
                                      Value *RHS, Value *Cond,
                                      const SimplifyQuery &Q,
                                      unsigned MaxRecurse) {
@@ -176,7 +177,7 @@ static Value *simplifyCmpSelTrueCase(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Simplify comparison with false branch of select
-static Value *simplifyCmpSelFalseCase(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyCmpSelFalseCase(CmpInst::PredicateSign Pred, Value *LHS,
                                       Value *RHS, Value *Cond,
                                       const SimplifyQuery &Q,
                                       unsigned MaxRecurse) {
@@ -471,7 +472,7 @@ static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
 /// We can simplify %cmp1 to true, because both branches of select are
 /// less than 3. We compose new comparison by substituting %tmp with both
 /// branches of select and see if it can be simplified.
-static Value *threadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS,
+static Value *threadCmpOverSelect(CmpInst::PredicateSign Pred, Value *LHS,
                                   Value *RHS, const SimplifyQuery &Q,
                                   unsigned MaxRecurse) {
   // Recursion is always used, so bail out at once if we already hit the limit.
@@ -564,8 +565,9 @@ static Value *threadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS,
 /// comparison by seeing whether comparing with all of the incoming phi values
 /// yields the same result every time. If so returns the common result,
 /// otherwise returns null.
-static Value *threadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
-                               const SimplifyQuery &Q, unsigned MaxRecurse) {
+static Value *threadCmpOverPHI(CmpInst::PredicateSign Pred, Value *LHS,
+                               Value *RHS, const SimplifyQuery &Q,
+                               unsigned MaxRecurse) {
   // Recursion is always used, so bail out at once if we already hit the limit.
   if (!MaxRecurse--)
     return nullptr;
@@ -1001,7 +1003,7 @@ Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
 /// Given a predicate and two operands, return true if the comparison is true.
 /// This is a helper for div/rem simplification where we return some other value
 /// when we can prove a relationship between the operands.
-static bool isICmpTrue(ICmpInst::Predicate Pred, Value *LHS, Value *RHS,
+static bool isICmpTrue(ICmpInst::PredicateSign Pred, Value *LHS, Value *RHS,
                        const SimplifyQuery &Q, unsigned MaxRecurse) {
   Value *V = simplifyICmpInst(Pred, LHS, RHS, Q, MaxRecurse);
   Constant *C = dyn_cast_or_null<Constant>(V);
@@ -2597,7 +2599,7 @@ static Type *getCompareTy(Value *Op) {
 /// Rummage around inside V looking for something equivalent to the comparison
 /// "LHS Pred RHS". Return such a value if found, otherwise return null.
 /// Helper function for analyzing max/min idioms.
-static Value *extractEquivalentCondition(Value *V, CmpInst::Predicate Pred,
+static Value *extractEquivalentCondition(Value *V, CmpInst::PredicateSign Pred,
                                          Value *LHS, Value *RHS) {
   SelectInst *SI = dyn_cast<SelectInst>(V);
   if (!SI)
@@ -2706,7 +2708,7 @@ static bool haveNonOverlappingStorage(const Value *V1, const Value *V2) {
 // If the C and C++ standards are ever made sufficiently restrictive in this
 // area, it may be possible to update LLVM's semantics accordingly and reinstate
 // this optimization.
-static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS,
+static Constant *computePointerICmp(CmpInst::PredicateSign Pred, Value *LHS,
                                     Value *RHS, const SimplifyQuery &Q) {
   assert(LHS->getType() == RHS->getType() && "Must have same types");
   const DataLayout &DL = Q.DL;
@@ -2855,7 +2857,7 @@ static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Fold an icmp when its operands have i1 scalar type.
-static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpOfBools(CmpInst::PredicateSign Pred, Value *LHS,
                                   Value *RHS, const SimplifyQuery &Q) {
   Type *ITy = getCompareTy(LHS); // The return type.
   Type *OpTy = LHS->getType();   // The operand type.
@@ -2958,7 +2960,7 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Try hard to fold icmp with zero RHS because this is a common case.
-static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithZero(CmpInst::PredicateSign Pred, Value *LHS,
                                    Value *RHS, const SimplifyQuery &Q) {
   if (!match(RHS, m_Zero()))
     return nullptr;
@@ -3018,7 +3020,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
   return nullptr;
 }
 
-static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithConstant(CmpInst::PredicateSign Pred, Value *LHS,
                                        Value *RHS, const InstrInfoQuery &IIQ) {
   Type *ITy = getCompareTy(RHS); // The return type.
 
@@ -3066,7 +3068,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   return nullptr;
 }
 
-static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
+static Value *simplifyICmpWithBinOpOnLHS(CmpInst::PredicateSign Pred,
                                          BinaryOperator *LBO, Value *RHS,
                                          const SimplifyQuery &Q,
                                          unsigned MaxRecurse) {
@@ -3223,7 +3225,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
 // *) C1 < C2 && C1 >= 0, or
 // *) C2 < C1 && C1 <= 0.
 //
-static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS,
+static bool trySimplifyICmpWithAdds(CmpInst::PredicateSign Pred, Value *LHS,
                                     Value *RHS, const InstrInfoQuery &IIQ) {
   // TODO: only support icmp slt for now.
   if (Pred != CmpInst::ICMP_SLT || !IIQ.UseInstrInfo)
@@ -3248,7 +3250,7 @@ static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS,
 /// TODO: A large part of this logic is duplicated in InstCombine's
 /// foldICmpBinOp(). We should be able to share that and avoid the code
 /// duplication.
-static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithBinOp(CmpInst::PredicateSign Pred, Value *LHS,
                                     Value *RHS, const SimplifyQuery &Q,
                                     unsigned MaxRecurse) {
   BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS);
@@ -3482,7 +3484,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
 
 /// simplify integer comparisons where at least one operand of the compare
 /// matches an integer min/max idiom.
-static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithMinMax(CmpInst::PredicateSign Pred, Value *LHS,
                                      Value *RHS, const SimplifyQuery &Q,
                                      unsigned MaxRecurse) {
   Type *ITy = getCompareTy(LHS); // The return type.
@@ -...
[truncated]

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two quick notes without looking into this in detail:

  • Can you please move this out of CmpInst? The fact that CmpInst::Predicate cannot be forward-declared has been a regular source of friction -- we sometimes use unsigned in headers just because of this. We can avoid repeating that mistake now...
  • I'm not a big fan of the name PredicateSign. That sounds like "the sign of the predicate" not "predicate plus samesign"...

@artagnon artagnon force-pushed the ir-cmp-predicatesign branch from f040f49 to b25c331 Compare November 19, 2024 21:57
@artagnon artagnon changed the title IR: introduce CmpInst::PredicateSign IR: introduce struct with CmpInst::Predicate and samesign Nov 19, 2024
@artagnon
Copy link
Contributor Author

* Can you please move this out of CmpInst? The fact that CmpInst::Predicate cannot be forward-declared has been a regular source of friction -- we sometimes use `unsigned` in headers just because of this. We can avoid repeating that mistake now...

Thanks for the context.

* I'm not a big fan of the name `PredicateSign`. That sounds like "the sign of the predicate" not "predicate plus samesign"...

... and here I was thinking I was being clever and elegant migrating CmpInst::Predicate -> CmpInst::PredicateSign.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please rebase?

llvm/include/llvm/Analysis/InstructionSimplify.h Outdated Show resolved Hide resolved
llvm/include/llvm/IR/CmpPredicate.h Outdated Show resolved Hide resolved
@artagnon
Copy link
Contributor Author

artagnon commented Nov 20, 2024

  • Can you please move this out of CmpInst? The fact that CmpInst::Predicate cannot be forward-declared has been a regular source of friction -- we sometimes use unsigned in headers just because of this. We can avoid repeating that mistake now...

Thanks for the context.

Perhaps we can eventually move CmpInst::Predicate into CmpPredicate, once we migrate the entire codebase (do you know if this will have significant compile-time impact? I've marked the functions in CmpPredicate constexpr so they're inlined). I was actually thinking about making functions like getInversePredicate() automatically work with CmpPredicate, but since nothing in IR can be marked virtual, we have no choice but to create CmpPredicate variants of them: I've added "Cmp" to the name, to avoid confusion between a function in CmpInst and an overshadowing one in ICmpInst.

@artagnon artagnon force-pushed the ir-cmp-predicatesign branch from 0bb0888 to 09ff9a0 Compare November 21, 2024 13:13
@artagnon artagnon removed the request for review from MaskRay November 22, 2024 16:57
@artagnon
Copy link
Contributor Author

Gentle ping.

llvm/include/llvm/IR/CmpPredicate.h Outdated Show resolved Hide resolved
llvm/include/llvm/IR/InstrTypes.h Outdated Show resolved Hide resolved
llvm/lib/Analysis/ValueTracking.cpp Outdated Show resolved Hide resolved
llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp Outdated Show resolved Hide resolved
llvm/include/llvm/IR/Instructions.h Outdated Show resolved Hide resolved
Introduce CmpInst::PredicateSign, an abstraction over a floating-point
predicate, and a pack of an integer predicate with samesign information,
in order to ease extending large portions of the codebase that take a
CmpInst::Predicate to respect the samesign flag.

We have chosen to demonstrate the utility of this new abstraction by
migrating ValueTracking, InstructionSimplify, and InstCombine from
CmpInst::Predicate to CmpInst::PredicateSign. There should be no
functional changes, as we don't perform any extra optimizations with
samesign in this patch.

The design approach taken by this patch allows for unaudited callers of
APIs that take a CmpInst::PredicateSign to silently drop the samesign
information; it does not pose a correctness issue, and allows us to
migrate the codebase piece-wise.
@artagnon artagnon force-pushed the ir-cmp-predicatesign branch from e2aee89 to 94cc891 Compare December 3, 2024 10:28
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM as a starting point.

llvm/include/llvm/IR/CmpPredicate.h Outdated Show resolved Hide resolved
llvm/lib/Analysis/InstructionSimplify.cpp Outdated Show resolved Hide resolved
@artagnon artagnon merged commit 51a895a into llvm:main Dec 3, 2024
8 checks passed
@artagnon artagnon deleted the ir-cmp-predicatesign branch December 3, 2024 13:31
@nikic
Copy link
Contributor

nikic commented Dec 3, 2024

@artagnon
Copy link
Contributor Author

artagnon commented Dec 3, 2024

Looks like this has compile-time overhead: https://llvm-compile-time-tracker.com/compare.php?from=f33536468b7f05c05c8cf8088427b0b5b665eb65&to=51a895aded890e90493be59f7af0fa5a3b9b85aa&stat=instructions:u

What choice do we have? I think the samesign feature, by design, will add compile-time overhead to enable throughout the entire compiler flow.

TIFitis pushed a commit to TIFitis/llvm-project that referenced this pull request Dec 18, 2024
Introduce llvm::CmpPredicate, an abstraction over a floating-point
predicate, and a pack of an integer predicate with samesign information,
in order to ease extending large portions of the codebase that take a
CmpInst::Predicate to respect the samesign flag.

We have chosen to demonstrate the utility of this new abstraction by
migrating parts of ValueTracking, InstructionSimplify, and InstCombine
from CmpInst::Predicate to llvm::CmpPredicate. There should be no
functional changes, as we don't perform any extra optimizations with
samesign in this patch, or use CmpPredicate::getMatching.

The design approach taken by this patch allows for unaudited callers of
APIs that take a llvm::CmpPredicate to silently drop the samesign
information; it does not pose a correctness issue, and allows us to
migrate the codebase piece-wise.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants