Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ConstantFPRange] Implement ConstantFPRange::makeAllowedFCmpRegion #110082

Merged
merged 3 commits into from
Oct 2, 2024

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Sep 26, 2024

Note: the return type of makeExactFCmpRegion is changed to std::optional<ConstantFPRange> because I realized that we cannot represent the result of makeExactFCmpRegion(one, X) as a ConstantFPRange.

@llvmbot
Copy link
Member

llvmbot commented Sep 26, 2024

@llvm/pr-subscribers-llvm-ir

Author: Yingwei Zheng (dtcxzyw)

Changes

Note: the return type of makeExactFCmpRegion is changed to std::optional&lt;ConstantFPRange&gt; because I realized that we cannot represent the result of makeExactFCmpRegion(one, X) as a ConstantFPRange.


Full diff: https://github.com/llvm/llvm-project/pull/110082.diff

3 Files Affected:

  • (modified) llvm/include/llvm/IR/ConstantFPRange.h (+12-5)
  • (modified) llvm/lib/IR/ConstantFPRange.cpp (+112-7)
  • (modified) llvm/unittests/IR/ConstantFPRangeTest.cpp (+45)
diff --git a/llvm/include/llvm/IR/ConstantFPRange.h b/llvm/include/llvm/IR/ConstantFPRange.h
index 67f9f945d748ba..cab3a860eaf4ef 100644
--- a/llvm/include/llvm/IR/ConstantFPRange.h
+++ b/llvm/include/llvm/IR/ConstantFPRange.h
@@ -50,7 +50,6 @@ class [[nodiscard]] ConstantFPRange {
 
   void makeEmpty();
   void makeFull();
-  bool isNaNOnly() const;
 
   /// Initialize a full or empty set for the specified semantics.
   explicit ConstantFPRange(const fltSemantics &Sem, bool IsFullSet);
@@ -78,6 +77,9 @@ class [[nodiscard]] ConstantFPRange {
   /// Helper for (-inf, inf) to represent all finite values.
   static ConstantFPRange getFinite(const fltSemantics &Sem);
 
+  /// Helper for [-inf, inf] to represent all non-NaN values.
+  static ConstantFPRange getNonNaN(const fltSemantics &Sem);
+
   /// Create a range which doesn't contain NaNs.
   static ConstantFPRange getNonNaN(APFloat LowerVal, APFloat UpperVal) {
     return ConstantFPRange(std::move(LowerVal), std::move(UpperVal),
@@ -123,8 +125,10 @@ class [[nodiscard]] ConstantFPRange {
   /// { x : fcmp op x y is true}'.
   ///
   /// Example: Pred = olt and Other = float 3 returns [-inf, 3)
-  static ConstantFPRange makeExactFCmpRegion(FCmpInst::Predicate Pred,
-                                             const APFloat &Other);
+  /// If the exact answer is not representable as a ConstantFPRange, return
+  /// std::nullopt.
+  static std::optional<ConstantFPRange>
+  makeExactFCmpRegion(FCmpInst::Predicate Pred, const APFloat &Other);
 
   /// Does the predicate \p Pred hold between ranges this and \p Other?
   /// NOTE: false does not mean that inverse predicate holds!
@@ -139,6 +143,7 @@ class [[nodiscard]] ConstantFPRange {
   bool containsNaN() const { return MayBeQNaN || MayBeSNaN; }
   bool containsQNaN() const { return MayBeQNaN; }
   bool containsSNaN() const { return MayBeSNaN; }
+  bool isNaNOnly() const;
 
   /// Get the semantics of this ConstantFPRange.
   const fltSemantics &getSemantics() const { return Lower.getSemantics(); }
@@ -157,10 +162,12 @@ class [[nodiscard]] ConstantFPRange {
   bool contains(const ConstantFPRange &CR) const;
 
   /// If this set contains a single element, return it, otherwise return null.
-  const APFloat *getSingleElement() const;
+  const APFloat *getSingleElement(bool ExcludesNaN = false) const;
 
   /// Return true if this set contains exactly one member.
-  bool isSingleElement() const { return getSingleElement() != nullptr; }
+  bool isSingleElement(bool ExcludesNaN = false) const {
+    return getSingleElement(ExcludesNaN) != nullptr;
+  }
 
   /// Return true if the sign bit of all values in this range is 1.
   /// Return false if the sign bit of all values in this range is 0.
diff --git a/llvm/lib/IR/ConstantFPRange.cpp b/llvm/lib/IR/ConstantFPRange.cpp
index 957701891c8f37..9f9e4f69a4079d 100644
--- a/llvm/lib/IR/ConstantFPRange.cpp
+++ b/llvm/lib/IR/ConstantFPRange.cpp
@@ -103,11 +103,115 @@ ConstantFPRange ConstantFPRange::getNaNOnly(const fltSemantics &Sem,
                          MayBeSNaN);
 }
 
+ConstantFPRange ConstantFPRange::getNonNaN(const fltSemantics &Sem) {
+  return ConstantFPRange(APFloat::getInf(Sem, /*Negative=*/true),
+                         APFloat::getInf(Sem, /*Negative=*/false),
+                         /*MayBeQNaN=*/false, /*MayBeSNaN=*/false);
+}
+
+/// Return [-inf, V) or [-inf, V]
+static ConstantFPRange makeLessThan(APFloat V, FCmpInst::Predicate Pred) {
+  const fltSemantics &Sem = V.getSemantics();
+  if (!(Pred & FCmpInst::FCMP_OEQ)) {
+    if (V.isNegInfinity())
+      return ConstantFPRange::getEmpty(Sem);
+    V.next(/*nextDown=*/true);
+  }
+  return ConstantFPRange::getNonNaN(APFloat::getInf(Sem, /*Negative=*/true),
+                                    std::move(V));
+}
+
+/// Return (V, +inf] or [V, +inf]
+static ConstantFPRange makeGreaterThan(APFloat V, FCmpInst::Predicate Pred) {
+  const fltSemantics &Sem = V.getSemantics();
+  if (!(Pred & FCmpInst::FCMP_OEQ)) {
+    if (V.isPosInfinity())
+      return ConstantFPRange::getEmpty(Sem);
+    V.next(/*nextDown=*/false);
+  }
+  return ConstantFPRange::getNonNaN(std::move(V),
+                                    APFloat::getInf(Sem, /*Negative=*/false));
+}
+
+/// Make sure that +0/-0 are both included in the range.
+static ConstantFPRange extendZeroIfEqual(const ConstantFPRange &CR,
+                                         FCmpInst::Predicate Pred) {
+  if (!(Pred & FCmpInst::FCMP_OEQ))
+    return CR;
+
+  APFloat Lower = CR.getLower();
+  APFloat Upper = CR.getUpper();
+  if (Lower.isPosZero())
+    Lower = APFloat::getZero(Lower.getSemantics(), /*Negative=*/true);
+  if (Upper.isNegZero())
+    Upper = APFloat::getZero(Upper.getSemantics(), /*Negative=*/false);
+  return ConstantFPRange(std::move(Lower), std::move(Upper), CR.containsQNaN(),
+                         CR.containsSNaN());
+}
+
+static ConstantFPRange setNaNField(const ConstantFPRange &CR,
+                                   FCmpInst::Predicate Pred) {
+  bool ContainsNaN = FCmpInst::isUnordered(Pred);
+  return ConstantFPRange(CR.getLower(), CR.getUpper(),
+                         /*MayBeQNaN=*/ContainsNaN, /*MayBeSNaN=*/ContainsNaN);
+}
+
 ConstantFPRange
 ConstantFPRange::makeAllowedFCmpRegion(FCmpInst::Predicate Pred,
                                        const ConstantFPRange &Other) {
-  // TODO
-  return getFull(Other.getSemantics());
+  if (Other.isEmptySet())
+    return Other;
+  if (Other.containsNaN() && FCmpInst::isUnordered(Pred))
+    return getFull(Other.getSemantics());
+  if (Other.isNaNOnly() && FCmpInst::isOrdered(Pred))
+    return getEmpty(Other.getSemantics());
+
+  switch (Pred) {
+  case FCmpInst::FCMP_TRUE:
+    return getFull(Other.getSemantics());
+  case FCmpInst::FCMP_FALSE:
+    return getEmpty(Other.getSemantics());
+  case FCmpInst::FCMP_ORD:
+    return getNonNaN(Other.getSemantics());
+  case FCmpInst::FCMP_UNO:
+    return getNaNOnly(Other.getSemantics(), /*MayBeQNaN=*/true,
+                      /*MayBeSNaN=*/true);
+  case FCmpInst::FCMP_OEQ:
+  case FCmpInst::FCMP_UEQ:
+    return setNaNField(extendZeroIfEqual(Other, Pred), Pred);
+  case FCmpInst::FCMP_ONE:
+  case FCmpInst::FCMP_UNE:
+    if (const APFloat *SingleElement =
+            Other.getSingleElement(/*ExcludesNaN=*/true)) {
+      const fltSemantics &Sem = SingleElement->getSemantics();
+      if (SingleElement->isPosInfinity())
+        return setNaNField(
+            getNonNaN(APFloat::getInf(Sem, /*Negative=*/true),
+                      APFloat::getLargest(Sem, /*Negative=*/false)),
+            Pred);
+      if (SingleElement->isNegInfinity())
+        return setNaNField(
+            getNonNaN(APFloat::getLargest(Sem, /*Negative=*/true),
+                      APFloat::getInf(Sem, /*Negative=*/false)),
+            Pred);
+    }
+    return Pred == FCmpInst::FCMP_ONE ? getNonNaN(Other.getSemantics())
+                                      : getFull(Other.getSemantics());
+  case FCmpInst::FCMP_OLT:
+  case FCmpInst::FCMP_OLE:
+  case FCmpInst::FCMP_ULT:
+  case FCmpInst::FCMP_ULE:
+    return setNaNField(
+        extendZeroIfEqual(makeLessThan(Other.getUpper(), Pred), Pred), Pred);
+  case FCmpInst::FCMP_OGT:
+  case FCmpInst::FCMP_OGE:
+  case FCmpInst::FCMP_UGT:
+  case FCmpInst::FCMP_UGE:
+    return setNaNField(
+        extendZeroIfEqual(makeGreaterThan(Other.getLower(), Pred), Pred), Pred);
+  default:
+    llvm_unreachable("Unexpected predicate");
+  }
 }
 
 ConstantFPRange
@@ -117,9 +221,10 @@ ConstantFPRange::makeSatisfyingFCmpRegion(FCmpInst::Predicate Pred,
   return getEmpty(Other.getSemantics());
 }
 
-ConstantFPRange ConstantFPRange::makeExactFCmpRegion(FCmpInst::Predicate Pred,
-                                                     const APFloat &Other) {
-  return makeAllowedFCmpRegion(Pred, ConstantFPRange(Other));
+std::optional<ConstantFPRange>
+ConstantFPRange::makeExactFCmpRegion(FCmpInst::Predicate Pred,
+                                     const APFloat &Other) {
+  return std::nullopt;
 }
 
 bool ConstantFPRange::fcmp(FCmpInst::Predicate Pred,
@@ -161,8 +266,8 @@ bool ConstantFPRange::contains(const ConstantFPRange &CR) const {
          strictCompare(CR.Upper, Upper) != APFloat::cmpGreaterThan;
 }
 
-const APFloat *ConstantFPRange::getSingleElement() const {
-  if (MayBeSNaN || MayBeQNaN)
+const APFloat *ConstantFPRange::getSingleElement(bool ExcludesNaN) const {
+  if (!ExcludesNaN && (MayBeSNaN || MayBeQNaN))
     return nullptr;
   return Lower.bitwiseIsEqual(Upper) ? &Lower : nullptr;
 }
diff --git a/llvm/unittests/IR/ConstantFPRangeTest.cpp b/llvm/unittests/IR/ConstantFPRangeTest.cpp
index 722e6566730da5..1fe9231392d622 100644
--- a/llvm/unittests/IR/ConstantFPRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantFPRangeTest.cpp
@@ -161,6 +161,19 @@ static void EnumerateValuesInConstantFPRange(const ConstantFPRange &CR,
   }
 }
 
+template <typename Fn>
+static bool AnyOfValueInConstantFPRange(const ConstantFPRange &CR, Fn TestFn) {
+  const fltSemantics &Sem = CR.getSemantics();
+  unsigned Bits = APFloat::semanticsSizeInBits(Sem);
+  assert(Bits < 32 && "Too many bits");
+  for (unsigned I = 0, E = (1U << Bits) - 1; I != E; ++I) {
+    APFloat V(Sem, APInt(Bits, I));
+    if (CR.contains(V) && TestFn(V))
+      return true;
+  }
+  return false;
+}
+
 TEST_F(ConstantFPRangeTest, Basics) {
   EXPECT_TRUE(Full.isFullSet());
   EXPECT_FALSE(Full.isEmptySet());
@@ -263,12 +276,16 @@ TEST_F(ConstantFPRangeTest, SingleElement) {
   EXPECT_EQ(*One.getSingleElement(), APFloat(1.0));
   EXPECT_EQ(*PosZero.getSingleElement(), APFloat::getZero(Sem));
   EXPECT_EQ(*PosInf.getSingleElement(), APFloat::getInf(Sem));
+  ConstantFPRange PosZeroOrNaN = PosZero.unionWith(NaN);
+  EXPECT_EQ(*PosZeroOrNaN.getSingleElement(/*ExcludesNaN=*/true),
+            APFloat::getZero(Sem));
 
   EXPECT_FALSE(Full.isSingleElement());
   EXPECT_FALSE(Empty.isSingleElement());
   EXPECT_TRUE(One.isSingleElement());
   EXPECT_FALSE(Some.isSingleElement());
   EXPECT_FALSE(Zero.isSingleElement());
+  EXPECT_TRUE(PosZeroOrNaN.isSingleElement(/*ExcludesNaN=*/true));
 }
 
 TEST_F(ConstantFPRangeTest, ExhaustivelyEnumerate) {
@@ -425,4 +442,32 @@ TEST_F(ConstantFPRangeTest, MismatchedSemantics) {
 #endif
 #endif
 
+TEST_F(ConstantFPRangeTest, makeAllowedFCmpRegion) {
+  for (auto Pred : FCmpInst::predicates()) {
+    EnumerateConstantFPRanges(
+        [Pred](const ConstantFPRange &CR) {
+          ConstantFPRange Res =
+              ConstantFPRange::makeAllowedFCmpRegion(Pred, CR);
+          ConstantFPRange Optimal =
+              ConstantFPRange::getEmpty(CR.getSemantics());
+          EnumerateValuesInConstantFPRange(
+              ConstantFPRange::getFull(CR.getSemantics()),
+              [&](const APFloat &V) {
+                if (AnyOfValueInConstantFPRange(CR, [&](const APFloat &U) {
+                      return FCmpInst::compare(V, U, Pred);
+                    }))
+                  Optimal = Optimal.unionWith(ConstantFPRange(V));
+              });
+
+          ASSERT_TRUE(Res.contains(Optimal))
+              << "Wrong result for makeAllowedFCmpRegion(" << Pred << ", " << CR
+              << "). Expected " << Optimal << ", but got " << Res;
+          EXPECT_EQ(Res, Optimal)
+              << "Suboptimal result for makeAllowedFCmpRegion(" << Pred << ", "
+              << CR << ")";
+        },
+        /*Exhaustive=*/false);
+  }
+}
+
 } // anonymous namespace

@dtcxzyw dtcxzyw added the floating-point Floating-point math label Sep 26, 2024
llvm/lib/IR/ConstantFPRange.cpp Outdated Show resolved Hide resolved
llvm/unittests/IR/ConstantFPRangeTest.cpp Outdated Show resolved Hide resolved
dtcxzyw added a commit that referenced this pull request Oct 2, 2024
1. Address post-commit review comments
#86483 (comment).
2. Move some NFC changes from
#110082 to this patch.
@dtcxzyw dtcxzyw force-pushed the cfr-allowed-region branch from 2295fdc to 5da6f23 Compare October 2, 2024 08:09
@dtcxzyw dtcxzyw requested a review from jayfoad October 2, 2024 08:16
puja2196 pushed a commit to puja2196/LLVM-tutorial that referenced this pull request Oct 2, 2024
1. Address post-commit review comments
llvm/llvm-project#86483 (comment).
2. Move some NFC changes from
llvm/llvm-project#110082 to this patch.
VitaNuo pushed a commit to VitaNuo/llvm-project that referenced this pull request Oct 2, 2024
1. Address post-commit review comments
llvm#86483 (comment).
2. Move some NFC changes from
llvm#110082 to this patch.
VitaNuo pushed a commit to VitaNuo/llvm-project that referenced this pull request Oct 2, 2024
1. Address post-commit review comments
llvm#86483 (comment).
2. Move some NFC changes from
llvm#110082 to this patch.
@dtcxzyw dtcxzyw merged commit 5867362 into llvm:main Oct 2, 2024
6 of 8 checks passed
@dtcxzyw dtcxzyw deleted the cfr-allowed-region branch October 2, 2024 12:44
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Oct 3, 2024
1. Address post-commit review comments
llvm#86483 (comment).
2. Move some NFC changes from
llvm#110082 to this patch.
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Oct 3, 2024
…lvm#110082)

This patch adds `makeAllowedFCmpRegion` support for `ConstantFPRange`.
@vporpo
Copy link
Contributor

vporpo commented Oct 3, 2024

I just noticed that ConstantFPRangeTest.makeAllowedFCmpRegion takes a very long time to run, like 8 minutes, and slows down check-llvm. PTAL.

@aeubanks
Copy link
Contributor

aeubanks commented Oct 3, 2024

I just noticed that ConstantFPRangeTest.makeAllowedFCmpRegion takes a very long time to run, like 8 minutes, and slows down check-llvm. PTAL.

iterating over all possible float values in AnyOfValueInConstantFPRange is not good for test times

@arsenm
Copy link
Contributor

arsenm commented Oct 3, 2024

iterating over all possible float values in AnyOfValueInConstantFPRange is not good for test times

It's supposed to be all fp8 values, which is a much smaller range

@aeubanks
Copy link
Contributor

aeubanks commented Oct 3, 2024

I think the problem is that we're doing that nested 3 deep?

dtcxzyw added a commit that referenced this pull request Oct 4, 2024
…es in a range (#111083)

NaN payloads can be ignored because they are unrelated with
ConstantFPRange (except the conversion from ConstantFPRange to
KnownBits). This patch just enumerates `+/-[S/Q]NaN` to avoid
enumerating 32 NaN values in all ranges which contain NaN values.
Addresses comment
#110082 (comment).
This patch reduces the execution time for unittests from 30.37s to
10.59s with an optimized build.
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 4, 2024
1. Address post-commit review comments
llvm#86483 (comment).
2. Move some NFC changes from
llvm#110082 to this patch.
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 4, 2024
…lvm#110082)

This patch adds `makeAllowedFCmpRegion` support for `ConstantFPRange`.
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 4, 2024
…es in a range (llvm#111083)

NaN payloads can be ignored because they are unrelated with
ConstantFPRange (except the conversion from ConstantFPRange to
KnownBits). This patch just enumerates `+/-[S/Q]NaN` to avoid
enumerating 32 NaN values in all ranges which contain NaN values.
Addresses comment
llvm#110082 (comment).
This patch reduces the execution time for unittests from 30.37s to
10.59s with an optimized build.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
floating-point Floating-point math llvm:ir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants