-
Notifications
You must be signed in to change notification settings - Fork 226
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b0d1e4f
commit c692c20
Showing
3 changed files
with
298 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
// (C) Copyright Nick Thompson 2019. | ||
// Use, modification and distribution are subject to the | ||
// Boost Software License, Version 1.0. (See accompanying file | ||
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) | ||
#ifndef BOOST_MATH_TOOLS_CUBIC_ROOTS_HPP | ||
#define BOOST_MATH_TOOLS_CUBIC_ROOTS_HPP | ||
#include <array> | ||
#include <algorithm> | ||
#include <boost/math/tools/roots.hpp> | ||
|
||
namespace boost::math::tools { | ||
|
||
namespace detail { | ||
template <typename Real> int sgn(Real val) { | ||
return (Real(0) < val) - (val < Real(0)); | ||
} | ||
} | ||
// Solves ax³ + bx² + cx + d = 0. | ||
// Only returns the real roots, as types get weird for real coefficients and complex roots. | ||
// Follows Numerical Recipes, Chapter 5, section 6. | ||
template<typename Real> | ||
std::array<Real, 3> cubic_roots(Real a, Real b, Real c, Real d) { | ||
using std::sqrt; | ||
using std::acos; | ||
using std::cos; | ||
using std::cbrt; | ||
using std::abs; | ||
using std::fma; | ||
std::array<Real, 3> roots = {std::numeric_limits<Real>::quiet_NaN(), | ||
std::numeric_limits<Real>::quiet_NaN(), | ||
std::numeric_limits<Real>::quiet_NaN()}; | ||
if (a == 0) { | ||
// bx^2 + cx + d = 0: | ||
if (b == 0) { | ||
// cx + d = 0: | ||
if (c == 0) { | ||
if (d != 0) { | ||
// No solutions: | ||
return roots; | ||
} | ||
roots[0] = 0; | ||
roots[1] = 0; | ||
roots[2] = 0; | ||
return roots; | ||
} | ||
roots[0] = -d/c; | ||
return roots; | ||
} | ||
auto [x0, x1] = quadratic_roots(b, c, d); | ||
roots[0] = x0; | ||
roots[1] = x1; | ||
return roots; | ||
} | ||
if (d == 0) { | ||
auto [x0, x1] = quadratic_roots(a, b, c); | ||
roots[0] = x0; | ||
roots[1] = x1; | ||
roots[2] = 0; | ||
std::sort(roots.begin(), roots.end()); | ||
return roots; | ||
} | ||
Real p = b/a; | ||
Real q = c/a; | ||
Real r = d/a; | ||
Real Q = (p*p - 3*q)/9; | ||
Real R = (2*p*p*p - 9*p*q + 27*r)/54; | ||
if (R*R < Q*Q*Q) { | ||
//std::cout << "In the R^2 < Q^3 branch.\n"; | ||
Real rtQ = sqrt(Q); | ||
Real theta = acos(R/(Q*rtQ))/3; | ||
Real st = sin(theta); | ||
Real ct = cos(theta); | ||
roots[0] = -2*rtQ*ct - p/3; | ||
roots[1] = -rtQ*(-ct + sqrt(Real(3))*st) - p/3; | ||
roots[2] = rtQ*(ct + sqrt(Real(3))*st) - p/3; | ||
// This formula is not super accurate. | ||
// Do a cleanup iteration. | ||
for (auto & r : roots) { | ||
// Horner's method. | ||
// Here I'll take John Gustaffson's opinion that the fma is a *distinct* operation from a*x +b: | ||
// Make sure to compile these fmas into a single instruction! | ||
Real f = fma(a, r, b); | ||
f = fma(f,r,c); | ||
f = fma(f,r,d); | ||
Real df = fma(3*a, r, 2*b); | ||
df = fma(df, r, c); | ||
if (df != 0) { | ||
// No standard library feature for fused-divide add! | ||
r -= f/df; | ||
} | ||
} | ||
std::sort(roots.begin(), roots.end()); | ||
return roots; | ||
} | ||
// In Numerical Recipes, Chapter 5, Section 6, it is claimed that we only have one real root | ||
// if R^2 >= Q^3. But this isn't true; we can even see this from equation 5.6.18. | ||
// The condition for having three real roots is that A = B. | ||
// It *is* the case that if we're in this branch, and we have 3 real roots, two are a double root. | ||
// Take (x+1)^2(x-2) = x^3 - 3x -2 as an example. This clearly has a double root at x = -1, | ||
// and it gets sent into this branch. | ||
Real arg = R*R - Q*Q*Q; | ||
Real A = -detail::sgn(R)*cbrt(abs(R) + sqrt(arg)); | ||
Real B = 0; | ||
if (A != 0) { | ||
B = Q/A; | ||
} | ||
roots[0] = A + B - p/3; | ||
// Yes, we're comparing floats for equality: | ||
// Any perturbation pushes the roots into the complex plane; out of the bailiwick of this routine. | ||
if (A == B || arg == 0) { | ||
roots[1] = -A - p/3; | ||
roots[2] = -A - p/3; | ||
} | ||
for (auto & r : roots) { | ||
Real f = fma(a, r, b); | ||
f = fma(f,r,c); | ||
f = fma(f,r,d); | ||
Real df = fma(3*a, r, 2*b); | ||
df = fma(df, r, c); | ||
if (df != 0) { | ||
// No standard library feature for fused-divide add! | ||
r -= f/df; | ||
} | ||
} | ||
std::sort(roots.begin(), roots.end()); | ||
return roots; | ||
} | ||
|
||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// (C) Copyright Nick Thompson 2021. | ||
// Use, modification and distribution are subject to the | ||
// Boost Software License, Version 1.0. (See accompanying file | ||
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) | ||
|
||
#include <random> | ||
#include <array> | ||
#include <vector> | ||
#include <iostream> | ||
#include <benchmark/benchmark.h> | ||
#include <boost/math/tools/cubic_roots.hpp> | ||
|
||
using boost::math::tools::cubic_roots; | ||
|
||
template<class Real> | ||
void CubicRoots(benchmark::State& state) | ||
{ | ||
std::random_device rd; | ||
//auto seed = rd(); | ||
uint32_t seed = 416683252; | ||
std::mt19937_64 mt(seed); | ||
std::uniform_real_distribution<Real> unif(-10, 10); | ||
|
||
Real a = unif(mt); | ||
Real b = unif(mt); | ||
Real c = unif(mt); | ||
Real d = unif(mt); | ||
for (auto _ : state) | ||
{ | ||
auto roots = cubic_roots(a,b,c,d); | ||
benchmark::DoNotOptimize(roots[0]); | ||
} | ||
std::cout << "Just solved " << a << "x^3 + " << b << "x^2 + " << c << "x + " << d << "\n"; | ||
std::cout << "This was generated by seed " << seed << "\n"; | ||
} | ||
|
||
//BENCHMARK_TEMPLATE(CubicRoots, float); | ||
BENCHMARK_TEMPLATE(CubicRoots, double); | ||
//BENCHMARK_TEMPLATE(CubicRoots, long double); | ||
|
||
BENCHMARK_MAIN(); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
/* | ||
* Copyright Nick Thompson, 2021 | ||
* Use, modification and distribution are subject to the | ||
* Boost Software License, Version 1.0. (See accompanying file | ||
* LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) | ||
*/ | ||
|
||
#include "math_unit_test.hpp" | ||
#include <random> | ||
#include <boost/math/tools/cubic_roots.hpp> | ||
#ifdef BOOST_HAS_FLOAT128 | ||
#include <boost/multiprecision/float128.hpp> | ||
using boost::multiprecision::float128; | ||
#endif | ||
|
||
using boost::math::tools::cubic_roots; | ||
using std::cbrt; | ||
|
||
template<class Real> | ||
void test_zero_coefficients() | ||
{ | ||
Real a = 0; | ||
Real b = 0; | ||
Real c = 0; | ||
Real d = 0; | ||
auto roots = cubic_roots(a,b,c,d); | ||
CHECK_EQUAL(roots[0], Real(0)); | ||
CHECK_EQUAL(roots[1], Real(0)); | ||
CHECK_EQUAL(roots[2], Real(0)); | ||
|
||
a = 1; | ||
roots = cubic_roots(a,b,c,d); | ||
CHECK_EQUAL(roots[0], Real(0)); | ||
CHECK_EQUAL(roots[1], Real(0)); | ||
CHECK_EQUAL(roots[2], Real(0)); | ||
|
||
a = 1; | ||
d = 1; | ||
// x^3 + 1 = 0: | ||
roots = cubic_roots(a,b,c,d); | ||
CHECK_EQUAL(roots[0], Real(-1)); | ||
CHECK_NAN(roots[1]); | ||
CHECK_NAN(roots[2]); | ||
d = -1; | ||
// x^3 - 1 = 0: | ||
roots = cubic_roots(a,b,c,d); | ||
CHECK_EQUAL(roots[0], Real(1)); | ||
CHECK_NAN(roots[1]); | ||
CHECK_NAN(roots[2]); | ||
|
||
d = -2; | ||
// x^3 - 2 = 0 | ||
roots = cubic_roots(a,b,c,d); | ||
// Let's go for equality here! | ||
CHECK_EQUAL(roots[0], cbrt(Real(2))); | ||
CHECK_NAN(roots[1]); | ||
CHECK_NAN(roots[2]); | ||
|
||
d = -8; | ||
roots = cubic_roots(a,b,c,d); | ||
CHECK_EQUAL(roots[0], Real(2)); | ||
CHECK_NAN(roots[1]); | ||
CHECK_NAN(roots[2]); | ||
|
||
|
||
// (x-1)(x-2)(x-3) = x^3 - 6x^2 + 11x - 6 | ||
roots = cubic_roots(Real(1), Real(-6), Real(11), Real(-6)); | ||
CHECK_ULP_CLOSE(roots[0], Real(1), 2); | ||
CHECK_ULP_CLOSE(roots[1], Real(2), 2); | ||
CHECK_ULP_CLOSE(roots[2], Real(3), 2); | ||
|
||
// Double root: | ||
// (x+1)^2(x-2) = x^3 - 3x - 2: | ||
// Note: This test is unstable wrt to perturbations! | ||
roots = cubic_roots(Real(1), Real(0), Real(-3), Real(-2)); | ||
CHECK_ULP_CLOSE(-1, roots[0], 2); | ||
CHECK_ULP_CLOSE(-1, roots[1], 2); | ||
CHECK_ULP_CLOSE(2, roots[2], 2); | ||
|
||
std::uniform_real_distribution<Real> dis(-2,2); | ||
std::mt19937 gen(12345); | ||
// Expected roots | ||
std::array<Real, 3> r; | ||
int trials = 10; | ||
for (int i = 0; i < trials; ++i) { | ||
// Mathematica: | ||
// Expand[(x - r0)*(x - r1)*(x - r2)] | ||
// - r0 r1 r2 + (r0 r1 + r0 r2 + r1 r2) x | ||
// - (r0 + r1 + r2) x^2 + x^3 | ||
for (auto & root : r) { | ||
root = static_cast<Real>(dis(gen)); | ||
} | ||
std::sort(r.begin(), r.end()); | ||
Real a = 1; | ||
Real b = -(r[0] + r[1] + r[2]); | ||
Real c = r[0]*r[1] + r[0]*r[2] + r[1]*r[2]; | ||
Real d = -r[0]*r[1]*r[2]; | ||
|
||
auto roots = cubic_roots(a, b, c, d); | ||
// I could check the condition number here, but this is fine right? | ||
if(!CHECK_ULP_CLOSE(r[0], roots[0], 3)) { | ||
std::cerr << " Polynomial x^3 + " << b << "x^2 + " << c << "x + " << d << " has roots {"; | ||
std::cerr << r[0] << ", " << r[1] << ", " << r[2] << "}, but the computed roots are {"; | ||
std::cerr << roots[0] << ", " << roots[1] << ", " << roots[2] << "}\n"; | ||
} | ||
CHECK_ULP_CLOSE(r[1], roots[1], 3); | ||
CHECK_ULP_CLOSE(r[2], roots[2], 3); | ||
} | ||
} | ||
|
||
|
||
int main() | ||
{ | ||
test_zero_coefficients<float>(); | ||
test_zero_coefficients<double>(); | ||
#ifndef BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS | ||
test_zero_coefficients<long double>(); | ||
#endif | ||
|
||
#ifdef BOOST_HAS_FLOAT128 | ||
// For some reason, the quadmath is way less accurate than the float/double/long double: | ||
//test_zero_coefficients<float128>(); | ||
#endif | ||
|
||
|
||
return boost::math::test::report_errors(); | ||
} |