Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[libc][math][c23] Add MPFR exhaustive test for fmodf16 #94656

Merged
merged 4 commits into from
Jun 25, 2024

Conversation

overmighty
Copy link
Member

No description provided.

@overmighty
Copy link
Member Author

overmighty commented Jun 6, 2024

Based on #94629 and #94383.

@overmighty overmighty force-pushed the libc-math-fmodf16-exhaustive-test branch from 7e0de47 to bd989f9 Compare June 24, 2024 21:03
@overmighty
Copy link
Member Author

Rebased, resolved merge conflicts, and refactored the changes in exhaustive_test.h a lot.

@overmighty overmighty marked this pull request as ready for review June 24, 2024 21:07
@overmighty
Copy link
Member Author

cc @lntue

@llvmbot llvmbot added the libc label Jun 24, 2024
@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2024

@llvm/pr-subscribers-libc

Author: OverMighty (overmighty)

Changes

Patch is 22.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94656.diff

5 Files Affected:

  • (modified) libc/test/src/math/exhaustive/CMakeLists.txt (+19)
  • (modified) libc/test/src/math/exhaustive/exhaustive_test.h (+108-14)
  • (added) libc/test/src/math/exhaustive/fmodf16_test.cpp (+41)
  • (modified) libc/utils/MPFRWrapper/MPFRUtils.cpp (+59-31)
  • (modified) libc/utils/MPFRWrapper/MPFRUtils.h (+28-11)
diff --git a/libc/test/src/math/exhaustive/CMakeLists.txt b/libc/test/src/math/exhaustive/CMakeLists.txt
index 34df8720ed4db..fb3596c3378ff 100644
--- a/libc/test/src/math/exhaustive/CMakeLists.txt
+++ b/libc/test/src/math/exhaustive/CMakeLists.txt
@@ -4,6 +4,10 @@ add_header_library(
   exhaustive_test
   HDRS
     exhaustive_test.h
+  DEPENDS
+    libc.src.__support.CPP.type_traits
+    libc.src.__support.FPUtil.fp_bits
+    libc.src.__support.macros.properties.types
 )
 
 add_fp_unittest(
@@ -277,6 +281,21 @@ add_fp_unittest(
     libc.src.__support.FPUtil.generic.fmod
 )
 
+add_fp_unittest(
+  fmodf16_test
+  NO_RUN_POSTBUILD
+  NEED_MPFR
+  SUITE
+    libc_math_exhaustive_tests
+  SRCS
+    fmodf16_test.cpp
+  DEPENDS
+    .exhaustive_test
+    libc.src.math.fmodf16
+  LINK_LIBRARIES
+    -lpthread
+)
+
 add_fp_unittest(
   coshf_test
   NO_RUN_POSTBUILD
diff --git a/libc/test/src/math/exhaustive/exhaustive_test.h b/libc/test/src/math/exhaustive/exhaustive_test.h
index 13e272783250b..bf9be90f39e2c 100644
--- a/libc/test/src/math/exhaustive/exhaustive_test.h
+++ b/libc/test/src/math/exhaustive/exhaustive_test.h
@@ -8,6 +8,7 @@
 
 #include "src/__support/CPP/type_traits.h"
 #include "src/__support/FPUtil/FPBits.h"
+#include "src/__support/macros/properties/types.h"
 #include "test/UnitTest/FPMatcher.h"
 #include "test/UnitTest/Test.h"
 #include "utils/MPFRWrapper/MPFRUtils.h"
@@ -68,9 +69,46 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
   }
 };
 
+template <typename OutType, typename InType = OutType>
+using BinaryOp = OutType(InType, InType);
+
+template <typename OutType, typename InType, mpfr::Operation Op,
+          BinaryOp<OutType, InType> Func>
+struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
+  using FloatType = InType;
+  using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
+  using StorageType = typename FPBits::StorageType;
+
+  // Check in a range, return the number of failures.
+  uint64_t check(StorageType x_start, StorageType x_stop, StorageType y_start,
+                 StorageType y_stop, mpfr::RoundingMode rounding) {
+    mpfr::ForceRoundingMode r(rounding);
+    if (!r.success)
+      return x_stop > x_start || y_stop > y_start;
+    StorageType xbits = x_start;
+    uint64_t failed = 0;
+    do {
+      FloatType x = FPBits(xbits).get_val();
+      StorageType ybits = y_start;
+      do {
+        FloatType y = FPBits(ybits).get_val();
+        mpfr::BinaryInput<FloatType> input{x, y};
+        bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, input, Func(x, y),
+                                                         0.5, rounding);
+        failed += (!correct);
+        // Uncomment to print out failed values.
+        // if (!correct) {
+        //   EXPECT_MPFR_MATCH_ROUNDING(Op, input, Func(x, y), 0.5, rounding);
+        // }
+      } while (ybits++ < y_stop);
+    } while (xbits++ < x_stop);
+    return failed;
+  }
+};
+
 // Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide
 //   StorageType and check method.
-template <typename Checker>
+template <typename Checker, size_t Increment = 1 << 20>
 struct LlvmLibcExhaustiveMathTest
     : public virtual LIBC_NAMESPACE::testing::Test,
       public Checker {
@@ -78,12 +116,37 @@ struct LlvmLibcExhaustiveMathTest
   using FPBits = typename Checker::FPBits;
   using StorageType = typename Checker::StorageType;
 
-  static constexpr StorageType INCREMENT = (1 << 20);
+  static constexpr StorageType INCREMENT = Increment;
+
+  void explain_failed_range(std::stringstream &msg, StorageType x_begin,
+                            StorageType x_end) {
+#ifdef LIBC_TYPES_HAS_FLOAT16
+    using T = LIBC_NAMESPACE::cpp::conditional_t<
+        LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float, FloatType>;
+#else
+    using T = FloatType;
+#endif
+
+    msg << x_begin << " to " << x_end << " [0x" << std::hex << x_begin << ", 0x"
+        << x_end << "), [" << std::hexfloat
+        << static_cast<T>(FPBits(x_begin).get_val()) << ", "
+        << static_cast<T>(FPBits(x_end).get_val()) << ")";
+  }
+
+  void explain_failed_range(std::stringstream &msg, StorageType x_begin,
+                            StorageType x_end, StorageType y_begin,
+                            StorageType y_end) {
+    msg << "x ";
+    explain_failed_range(msg, x_begin, x_end);
+    msg << ", y ";
+    explain_failed_range(msg, y_begin, y_end);
+  }
 
   // Break [start, stop) into `nthreads` subintervals and apply *check to each
   // subinterval in parallel.
-  void test_full_range(StorageType start, StorageType stop,
-                       mpfr::RoundingMode rounding) {
+  template <typename... T>
+  void test_full_range(mpfr::RoundingMode rounding, StorageType start,
+                       StorageType stop, T... extra_range_bounds) {
     int n_threads = std::thread::hardware_concurrency();
     std::vector<std::thread> thread_list;
     std::mutex mx_cur_val;
@@ -120,15 +183,14 @@ struct LlvmLibcExhaustiveMathTest
             std::cout << msg.str() << std::flush;
           }
 
-          uint64_t failed_in_range =
-              Checker::check(range_begin, range_end, rounding);
+          uint64_t failed_in_range = Checker::check(
+              range_begin, range_end, extra_range_bounds..., rounding);
           if (failed_in_range > 0) {
             std::stringstream msg;
             msg << "Test failed for " << std::dec << failed_in_range
-                << " inputs in range: " << range_begin << " to " << range_end
-                << " [0x" << std::hex << range_begin << ", 0x" << range_end
-                << "), [" << std::hexfloat << FPBits(range_begin).get_val()
-                << ", " << FPBits(range_end).get_val() << ")\n";
+                << " inputs in range: ";
+            explain_failed_range(msg, start, stop, extra_range_bounds...);
+            msg << "\n";
             std::cerr << msg.str() << std::flush;
 
             failed.fetch_add(failed_in_range);
@@ -151,19 +213,46 @@ struct LlvmLibcExhaustiveMathTest
   void test_full_range_all_roundings(StorageType start, StorageType stop) {
     std::cout << "-- Testing for FE_TONEAREST in range [0x" << std::hex << start
               << ", 0x" << stop << ") --" << std::dec << std::endl;
-    test_full_range(start, stop, mpfr::RoundingMode::Nearest);
+    test_full_range(mpfr::RoundingMode::Nearest, start, stop);
 
     std::cout << "-- Testing for FE_UPWARD in range [0x" << std::hex << start
               << ", 0x" << stop << ") --" << std::dec << std::endl;
-    test_full_range(start, stop, mpfr::RoundingMode::Upward);
+    test_full_range(mpfr::RoundingMode::Upward, start, stop);
 
     std::cout << "-- Testing for FE_DOWNWARD in range [0x" << std::hex << start
               << ", 0x" << stop << ") --" << std::dec << std::endl;
-    test_full_range(start, stop, mpfr::RoundingMode::Downward);
+    test_full_range(mpfr::RoundingMode::Downward, start, stop);
 
     std::cout << "-- Testing for FE_TOWARDZERO in range [0x" << std::hex
               << start << ", 0x" << stop << ") --" << std::dec << std::endl;
-    test_full_range(start, stop, mpfr::RoundingMode::TowardZero);
+    test_full_range(mpfr::RoundingMode::TowardZero, start, stop);
+  };
+
+  void test_full_range_all_roundings(StorageType x_start, StorageType x_stop,
+                                     StorageType y_start, StorageType y_stop) {
+    std::cout << "-- Testing for FE_TONEAREST in x range [0x" << std::hex
+              << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
+              << ", 0x" << y_stop << ") --" << std::dec << std::endl;
+    test_full_range(mpfr::RoundingMode::Nearest, x_start, x_stop, y_start,
+                    y_stop);
+
+    std::cout << "-- Testing for FE_UPWARD in x range [0x" << std::hex
+              << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
+              << ", 0x" << y_stop << ") --" << std::dec << std::endl;
+    test_full_range(mpfr::RoundingMode::Upward, x_start, x_stop, y_start,
+                    y_stop);
+
+    std::cout << "-- Testing for FE_DOWNWARD in x range [0x" << std::hex
+              << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
+              << ", 0x" << y_stop << ") --" << std::dec << std::endl;
+    test_full_range(mpfr::RoundingMode::Downward, x_start, x_stop, y_start,
+                    y_stop);
+
+    std::cout << "-- Testing for FE_TOWARDZERO in x range [0x" << std::hex
+              << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
+              << ", 0x" << y_stop << ") --" << std::dec << std::endl;
+    test_full_range(mpfr::RoundingMode::TowardZero, x_start, x_stop, y_start,
+                    y_stop);
   };
 };
 
@@ -175,3 +264,8 @@ template <typename OutType, typename InType, mpfr::Operation Op,
           UnaryOp<OutType, InType> Func>
 using LlvmLibcUnaryNarrowingOpExhaustiveMathTest =
     LlvmLibcExhaustiveMathTest<UnaryOpChecker<OutType, InType, Op, Func>>;
+
+template <typename FloatType, mpfr::Operation Op, BinaryOp<FloatType> Func>
+using LlvmLibcBinaryOpExhaustiveMathTest =
+    LlvmLibcExhaustiveMathTest<BinaryOpChecker<FloatType, FloatType, Op, Func>,
+                               1 << 2>;
diff --git a/libc/test/src/math/exhaustive/fmodf16_test.cpp b/libc/test/src/math/exhaustive/fmodf16_test.cpp
new file mode 100644
index 0000000000000..067cec969a4f7
--- /dev/null
+++ b/libc/test/src/math/exhaustive/fmodf16_test.cpp
@@ -0,0 +1,41 @@
+//===-- Exhaustive test for fmodf16 ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "exhaustive_test.h"
+#include "src/math/fmodf16.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
+
+using LlvmLibcFmodf16ExhaustiveTest =
+    LlvmLibcBinaryOpExhaustiveMathTest<float16, mpfr::Operation::Fmod,
+                                       LIBC_NAMESPACE::fmodf16>;
+
+// Range: [0, Inf];
+static constexpr uint16_t POS_START = 0x0000U;
+static constexpr uint16_t POS_STOP = 0x7c00U;
+
+// Range: [-Inf, 0];
+static constexpr uint16_t NEG_START = 0x8000U;
+static constexpr uint16_t NEG_STOP = 0xfc00U;
+
+TEST_F(LlvmLibcFmodf16ExhaustiveTest, PostivePositiveRange) {
+  test_full_range_all_roundings(POS_START, POS_STOP, POS_START, POS_STOP);
+}
+
+TEST_F(LlvmLibcFmodf16ExhaustiveTest, PostiveNegativeRange) {
+  test_full_range_all_roundings(POS_START, POS_STOP, NEG_START, NEG_STOP);
+}
+
+TEST_F(LlvmLibcFmodf16ExhaustiveTest, NegativePositiveRange) {
+  test_full_range_all_roundings(NEG_START, NEG_STOP, POS_START, POS_STOP);
+}
+
+TEST_F(LlvmLibcFmodf16ExhaustiveTest, NegativeNegativeRange) {
+  test_full_range_all_roundings(NEG_START, NEG_STOP, POS_START, POS_STOP);
+}
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 2eac4dd8e199d..88aef3e6e10c5 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -296,6 +296,12 @@ class MPFRNumber {
     return result;
   }
 
+  MPFRNumber div(const MPFRNumber &b) const {
+    MPFRNumber result(*this);
+    mpfr_div(result.value, value, b.value, mpfr_rounding);
+    return result;
+  }
+
   MPFRNumber floor() const {
     MPFRNumber result(*this);
     mpfr_floor(result.value, value);
@@ -708,6 +714,8 @@ binary_operation_one_output(Operation op, InputType x, InputType y,
   switch (op) {
   case Operation::Atan2:
     return inputX.atan2(inputY);
+  case Operation::Div:
+    return inputX.div(inputY);
   case Operation::Fmod:
     return inputX.fmod(inputY);
   case Operation::Hypot:
@@ -885,42 +893,49 @@ template void explain_binary_operation_two_outputs_error<long double>(
     Operation, const BinaryInput<long double> &,
     const BinaryOutput<long double> &, double, RoundingMode);
 
-template <typename T>
-void explain_binary_operation_one_output_error(Operation op,
-                                               const BinaryInput<T> &input,
-                                               T libc_result,
-                                               double ulp_tolerance,
-                                               RoundingMode rounding) {
-  unsigned int precision = get_precision<T>(ulp_tolerance);
+template <typename InputType, typename OutputType>
+void explain_binary_operation_one_output_error(
+    Operation op, const BinaryInput<InputType> &input, OutputType libc_result,
+    double ulp_tolerance, RoundingMode rounding) {
+  unsigned int precision = get_precision<InputType>(ulp_tolerance);
   MPFRNumber mpfrX(input.x, precision);
   MPFRNumber mpfrY(input.y, precision);
-  FPBits<T> xbits(input.x);
-  FPBits<T> ybits(input.y);
+  FPBits<InputType> xbits(input.x);
+  FPBits<InputType> ybits(input.y);
   MPFRNumber mpfr_result =
       binary_operation_one_output(op, input.x, input.y, precision, rounding);
   MPFRNumber mpfrMatchValue(libc_result);
 
   tlog << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n';
-  tlog << "First input bits: " << str(FPBits<T>(input.x)) << '\n';
-  tlog << "Second input bits: " << str(FPBits<T>(input.y)) << '\n';
+  tlog << "First input bits: " << str(FPBits<InputType>(input.x)) << '\n';
+  tlog << "Second input bits: " << str(FPBits<InputType>(input.y)) << '\n';
 
   tlog << "Libc result: " << mpfrMatchValue.str() << '\n'
        << "MPFR result: " << mpfr_result.str() << '\n';
-  tlog << "Libc floating point result bits: " << str(FPBits<T>(libc_result))
-       << '\n';
+  tlog << "Libc floating point result bits: "
+       << str(FPBits<OutputType>(libc_result)) << '\n';
   tlog << "              MPFR rounded bits: "
-       << str(FPBits<T>(mpfr_result.as<T>())) << '\n';
+       << str(FPBits<OutputType>(mpfr_result.as<OutputType>())) << '\n';
   tlog << "ULP error: " << mpfr_result.ulp_as_mpfr_number(libc_result).str()
        << '\n';
 }
 
-template void explain_binary_operation_one_output_error<float>(
-    Operation, const BinaryInput<float> &, float, double, RoundingMode);
-template void explain_binary_operation_one_output_error<double>(
+template void
+explain_binary_operation_one_output_error(Operation, const BinaryInput<float> &,
+                                          float, double, RoundingMode);
+template void explain_binary_operation_one_output_error(
     Operation, const BinaryInput<double> &, double, double, RoundingMode);
-template void explain_binary_operation_one_output_error<long double>(
-    Operation, const BinaryInput<long double> &, long double, double,
-    RoundingMode);
+template void
+explain_binary_operation_one_output_error(Operation,
+                                          const BinaryInput<long double> &,
+                                          long double, double, RoundingMode);
+#ifdef LIBC_TYPES_HAS_FLOAT16
+template void explain_binary_operation_one_output_error(
+    Operation, const BinaryInput<float16> &, float16, double, RoundingMode);
+template void
+explain_binary_operation_one_output_error(Operation, const BinaryInput<float> &,
+                                          float16, double, RoundingMode);
+#endif
 
 template <typename InputType, typename OutputType>
 void explain_ternary_operation_one_output_error(
@@ -1051,12 +1066,13 @@ template bool compare_binary_operation_two_outputs<long double>(
     Operation, const BinaryInput<long double> &,
     const BinaryOutput<long double> &, double, RoundingMode);
 
-template <typename T>
+template <typename InputType, typename OutputType>
 bool compare_binary_operation_one_output(Operation op,
-                                         const BinaryInput<T> &input,
-                                         T libc_result, double ulp_tolerance,
+                                         const BinaryInput<InputType> &input,
+                                         OutputType libc_result,
+                                         double ulp_tolerance,
                                          RoundingMode rounding) {
-  unsigned int precision = get_precision<T>(ulp_tolerance);
+  unsigned int precision = get_precision<InputType>(ulp_tolerance);
   MPFRNumber mpfr_result =
       binary_operation_one_output(op, input.x, input.y, precision, rounding);
   double ulp = mpfr_result.ulp(libc_result);
@@ -1064,13 +1080,25 @@ bool compare_binary_operation_one_output(Operation op,
   return (ulp <= ulp_tolerance);
 }
 
-template bool compare_binary_operation_one_output<float>(
-    Operation, const BinaryInput<float> &, float, double, RoundingMode);
-template bool compare_binary_operation_one_output<double>(
-    Operation, const BinaryInput<double> &, double, double, RoundingMode);
-template bool compare_binary_operation_one_output<long double>(
-    Operation, const BinaryInput<long double> &, long double, double,
-    RoundingMode);
+template bool compare_binary_operation_one_output(Operation,
+                                                  const BinaryInput<float> &,
+                                                  float, double, RoundingMode);
+template bool compare_binary_operation_one_output(Operation,
+                                                  const BinaryInput<double> &,
+                                                  double, double, RoundingMode);
+template bool
+compare_binary_operation_one_output(Operation, const BinaryInput<long double> &,
+                                    long double, double, RoundingMode);
+#ifdef LIBC_TYPES_HAS_FLOAT16
+template bool compare_binary_operation_one_output(Operation,
+                                                  const BinaryInput<float16> &,
+                                                  float16, double,
+                                                  RoundingMode);
+template bool compare_binary_operation_one_output(Operation,
+                                                  const BinaryInput<float> &,
+                                                  float16, double,
+                                                  RoundingMode);
+#endif
 
 template <typename InputType, typename OutputType>
 bool compare_ternary_operation_one_output(Operation op,
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index 0b4f42a72ec81..46f3375fd4b7e 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -71,6 +71,7 @@ enum class Operation : int {
   // output.
   BeginBinaryOperationsSingleOutput,
   Atan2,
+  Div,
   Fmod,
   Hypot,
   Pow,
@@ -129,6 +130,14 @@ struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
   static constexpr bool VALUE = cpp::is_floating_point_v<T>;
 };
 
+template <typename T> struct IsBinaryInput {
+  static constexpr bool VALUE = false;
+};
+
+template <typename T> struct IsBinaryInput<BinaryInput<T>> {
+  static constexpr bool VALUE = true;
+};
+
 template <typename T> struct IsTernaryInput {
   static constexpr bool VALUE = false;
 };
@@ -139,6 +148,9 @@ template <typename T> struct IsTernaryInput<TernaryInput<T>> {
 
 template <typename T> struct MakeScalarInput : cpp::type_identity<T> {};
 
+template <typename T>
+struct MakeScalarInput<BinaryInput<T>> : cpp::type_identity<T> {};
+
 template <typename T>
 struct MakeScalarInput<TernaryInput<T>> : cpp::type_identity<T> {};
 
@@ -159,10 +171,11 @@ bool compare_binary_operation_two_outputs(Operation op,
                                           double ulp_tolerance,
                                           RoundingMode rounding);
 
-template <typename T>
+template <typename InputType, typename OutputType>
 bool compare_binary_operation_one_output(Operation op,
-                                         const BinaryInput<T> &input,
-                                         T libc_output, double ulp_tolerance,
+                                         const BinaryInput<InputType> &input,
+                                         OutputType libc_output,
+                                         double ulp_tolerance,
                                          RoundingMode rounding);
 
 template <typename InputType, typename OutputType>
@@ -187,12 +200,10 @@ void explain_binary_operation_tw...
[truncated]

@overmighty overmighty force-pushed the libc-math-fmodf16-exhaustive-test branch from b75b526 to 6894d5c Compare June 25, 2024 14:28
@overmighty
Copy link
Member Author

Rebased.

@lntue lntue merged commit ef05b03 into llvm:main Jun 25, 2024
6 checks passed
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants