Skip to content

Commit

Permalink
Quartic roots.
Browse files Browse the repository at this point in the history
  • Loading branch information
NAThompson authored and Nick Thompson committed Dec 30, 2021
1 parent 7108ccc commit 8937001
Show file tree
Hide file tree
Showing 6 changed files with 389 additions and 10 deletions.
40 changes: 40 additions & 0 deletions doc/roots/quartic_roots.qbk
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[/
Copyright (c) 2021 Nick Thompson
Use, modification and distribution are subject to the
Boost Software License, Version 1.0. (See accompanying file
LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
]

[section:quartic_roots Roots of Quartic Polynomials]

[heading Synopsis]

```
#include <boost/math/roots/quartic_roots.hpp>

namespace boost::math::tools {

// Solves ax⁴ + bx³ + cx² + dx + e = 0.
std::array<Real,3> quartic_roots(Real a, Real b, Real c, Real d, Real e);

}
```

[heading Background]

The `quartic_roots` function extracts all real roots of a quartic polynomial ax^4 + bx³ + cx² + dx + e.
The result is a `std::array<Real, 4>`, which has length four, irrespective of the number of real roots the polynomial possesses.
(This is to prevent the performance overhead of allocating a vector, which often exceeds the time to extract the roots.)
The roots are returned in nondecreasing order. If a root is complex, then it is placed at the back of the array and set to a nan.

[@https://en.wikipedia.org/wiki/Hardy_space Hardy space]
The algorithm uses the classical method of Ferrari, and follows [@https://github.com/erich666/GraphicsGems/blob/master/gems/Roots3And4.c Graphics Gems V],
with an additional Halley iterate for root polishing.

[heading Performance and Accuracy]

On a consumer laptop, we observe extraction of the roots taking ~90ns.
The file `reporting/performance/quartic_roots_performance.cpp` allows determination of the speed on your system.

[endsect]
[/section:quartic_roots]
1 change: 1 addition & 0 deletions doc/roots/roots_overview.qbk
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ There are several fully-worked __root_finding_examples, including:
[include roots_without_derivatives.qbk]
[include roots.qbk]
[include cubic_roots.qbk]
[include quartic_roots.qbk]
[include root_finding_examples.qbk]
[include minima.qbk]
[include root_comparison.qbk]
Expand Down
149 changes: 149 additions & 0 deletions include/boost/math/tools/quartic_roots.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// (C) Copyright Nick Thompson 2021.
// Use, modification and distribution are subject to the
// Boost Software License, Version 1.0. (See accompanying file
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
#ifndef BOOST_MATH_TOOLS_QUARTIC_ROOTS_HPP
#define BOOST_MATH_TOOLS_QUARTIC_ROOTS_HPP
#include <array>
#include <cmath>
#include <boost/math/tools/cubic_roots.hpp>

namespace boost::math::tools {

namespace detail {

// Make sure the nans are always at the back of the array:
template<typename Real>
bool comparator(Real r1, Real r2) {
if (std::isnan(r2)) { return true; }
return r1 < r2;
}

template<typename Real>
std::array<Real, 4> polish_and_sort(Real a, Real b, Real c, Real d, Real e, std::array<Real, 4>& roots) {
// Polish the roots with a Halley iterate.
using std::fma;
for (auto &r : roots) {
Real df = fma(4*a, r, 3*b);
df = fma(df, r, 2*c);
df = fma(df, r, d);
Real d2f = fma(12*a, r, 6*b);
d2f = fma(d2f, r, 2*c);
Real f = fma(a, r, b);
f = fma(f,r,c);
f = fma(f,r,d);
f = fma(f,r,e);
Real denom = 2*df*df - f*d2f;
if (std::abs(denom) > std::numeric_limits<Real>::min())
{
r -= 2*f*df/denom;
}
}
std::sort(roots.begin(), roots.end(), detail::comparator<Real>);
return roots;
}

}
// Solves ax⁴ + bx³ + cx² + dx + e = 0.
// Only returns the real roots, as these are the only roots of interest in ray intersection problems.
// Follows Graphics Gems V: https://github.com/erich666/GraphicsGems/blob/master/gems/Roots3And4.c
template<typename Real>
std::array<Real, 4> quartic_roots(Real a, Real b, Real c, Real d, Real e) {
using std::abs;
using std::sqrt;
auto nan = std::numeric_limits<Real>::quiet_NaN();
std::array<Real, 4> roots{nan, nan, nan, nan};
if (std::abs(a) <= std::numeric_limits<Real>::min()) {
auto cbrts = cubic_roots(b, c, d, e);
roots[0] = cbrts[0];
roots[1] = cbrts[1];
roots[2] = cbrts[2];
if (b == 0 && c == 0 && d == 0 && e == 0) {
roots[3] = 0;
}
return detail::polish_and_sort(a, b, c, d, e, roots);
}
if (std::abs(e) <= std::numeric_limits<Real>::min()) {
auto v = cubic_roots(a, b, c, d);
roots[0] = v[0];
roots[1] = v[1];
roots[2] = v[2];
roots[3] = 0;
return detail::polish_and_sort(a, b, c, d, e, roots);
}
// Now solve x⁴ + Ax³ + Bx² + Cx + D = 0.
Real A = b/a;
Real B = c/a;
Real C = d/a;
Real D = e/a;
Real Asq = A*A;
// Let x = y - A/4:
// Mathematica: Expand[(y - A/4)^4 + A*(y - A/4)^3 + B*(y - A/4)^2 + C*(y - A/4) + D]
// We now solve the depressed quartic y⁴ + py² + qy + r = 0.
Real p = B - 3*Asq/8;
Real q = C - A*B/2 + Asq*A/8;
Real r = D - A*C/4 + Asq*B/16 - 3*Asq*Asq/256;
if (std::abs(r) <= std::numeric_limits<Real>::min()) {
auto [r1, r2, r3] = cubic_roots(Real(1), Real(0), p, q);
r1 -= A/4;
r2 -= A/4;
r3 -= A/4;
roots[0] = r1;
roots[1] = r2;
roots[2] = r3;
roots[3] = -A/4;
return detail::polish_and_sort(a, b, c, d, e, roots);
}
// Biquadratic case:
if (std::abs(q) <= std::numeric_limits<Real>::min()) {
auto [r1, r2] = quadratic_roots(Real(1), p, r);
std::vector<Real> v;
if (r1 >= 0) {
Real rtr = sqrt(r1);
roots[0] = rtr - A/4;
roots[1] = -rtr - A/4;
}
if (r2 >= 0) {
Real rtr = sqrt(r2);
roots[2] = rtr - A/4;
roots[3] = -rtr - A/4;
}
return detail::polish_and_sort(a, b, c, d, e, roots);
}

// Now split the depressed quartic into two quadratics:
// y⁴ + py² + qy + r = (y² + sy + u)(y² - sy + v) = y⁴ + (v+u-s²)y² + s(v - u)y + uv
// So p = v+u-s², q = s(v - u), r = uv.
// Then (v+u)² - (v-u)² = 4uv = 4r = (p+s²)² - q²/s².
// Multiply through by s² to get s²(p+s²)² - q² - 4rs² = 0, which is a cubic in s².
// Then we let z = s², to get
// z³ + 2pz² + (p² - 4r)z - q² = 0.
auto z_roots = cubic_roots(Real(1), 2*p, p*p - 4*r, -q*q);
// z = s², so s = sqrt(z).
// No real roots:
if (z_roots.back() <= 0) {
return roots;
}
Real s = std::sqrt(z_roots.back());

// s is nonzero, because we took care of the biquadratic case.
Real v = (p + s*s + q/s)/2;
Real u = v - q/s;
// Now solve y² + sy + u = 0:
auto [root0, root1] = quadratic_roots(Real(1), s, u);

// Now solve y² - sy + v = 0:
auto [root2, root3] = quadratic_roots(Real(1), -s, v);
roots[0] = root0;
roots[1] = root1;
roots[2] = root2;
roots[3] = root3;

for (auto& r : roots) {
r -= A/4;
}
return detail::polish_and_sort(a, b, c, d, e, roots);
}

}
#endif
41 changes: 41 additions & 0 deletions reporting/performance/quartic_roots_performance.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// (C) Copyright Nick Thompson 2021.
// Use, modification and distribution are subject to the
// Boost Software License, Version 1.0. (See accompanying file
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

#include <random>
#include <array>
#include <vector>
#include <iostream>
#include <benchmark/benchmark.h>
#include <boost/math/tools/quartic_roots.hpp>

using boost::math::tools::quartic_roots;

template<class Real>
void QuarticRoots(benchmark::State& state)
{
std::random_device rd;
auto seed = rd();
// This seed generates 3 real roots:
//uint32_t seed = 416683252;
std::mt19937_64 mt(seed);
std::uniform_real_distribution<Real> unif(-10, 10);

Real a = unif(mt);
Real b = unif(mt);
Real c = unif(mt);
Real d = unif(mt);
Real e = unif(mt);
for (auto _ : state)
{
auto roots = quartic_roots(a,b,c,d, e);
benchmark::DoNotOptimize(roots[0]);
}
}

BENCHMARK_TEMPLATE(QuarticRoots, float);
BENCHMARK_TEMPLATE(QuarticRoots, double);
BENCHMARK_TEMPLATE(QuarticRoots, long double);

BENCHMARK_MAIN();
26 changes: 16 additions & 10 deletions test/math_unit_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@
#include <boost/math/tools/assert.hpp>
#include <boost/math/special_functions/next.hpp>
#include <boost/math/special_functions/trunc.hpp>
#include <boost/core/demangle.hpp>

#include <cxxabi.h>
namespace boost { namespace math { namespace test {

namespace detail {
static std::atomic<int64_t> global_error_count{0};
static std::atomic<int64_t> total_ulp_distance{0};

inline std::string demangle(char const * name)
{
int status = 0;
std::size_t size = 0;
return abi::__cxa_demangle( name, NULL, &size, &status );
}
}

template<class Real>
Expand Down Expand Up @@ -49,7 +55,7 @@ bool check_mollified_close(Real expected, Real computed, Real tol, std::string c
std::ios_base::fmtflags f( std::cerr.flags() );
std::cerr << std::setprecision(3);
std::cerr << "\033[0;31mError at " << filename << ":" << function << ":" << line << ":\n"
<< " \033[0m Mollified relative error in " << boost::core::demangle(typeid(Real).name())<< " precision is " << mollified_relative_error
<< " \033[0m Mollified relative error in " << detail::demangle(typeid(Real).name())<< " precision is " << mollified_relative_error
<< ", which exceeds " << tol << ", error/tol = " << mollified_relative_error/tol << ".\n"
<< std::setprecision(std::numeric_limits<Real>::max_digits10) << std::showpos
<< " Expected: " << std::defaultfloat << std::fixed << expected << std::hexfloat << " = " << expected << "\n"
Expand Down Expand Up @@ -77,8 +83,8 @@ bool check_ulp_close(PreciseReal expected1, Real computed, size_t ulps, std::str
if (sizeof(PreciseReal) < sizeof(Real)) {
std::ostringstream err;
err << "\n\tThe expected number must be computed in higher (or equal) precision than the number being tested.\n";
err << "\tType of expected is " << boost::core::demangle(typeid(PreciseReal).name()) << ", which occupies " << sizeof(PreciseReal) << " bytes.\n";
err << "\tType of computed is " << boost::core::demangle(typeid(Real).name()) << ", which occupies " << sizeof(Real) << " bytes.\n";
err << "\tType of expected is " << detail::demangle(typeid(PreciseReal).name()) << ", which occupies " << sizeof(PreciseReal) << " bytes.\n";
err << "\tType of computed is " << detail::demangle(typeid(Real).name()) << ", which occupies " << sizeof(Real) << " bytes.\n";
throw std::logic_error(err.str());
}
}
Expand All @@ -105,7 +111,7 @@ bool check_ulp_close(PreciseReal expected1, Real computed, size_t ulps, std::str
std::ios_base::fmtflags f( std::cerr.flags() );
std::cerr << std::setprecision(3);
std::cerr << "\033[0;31mError at " << filename << ":" << function << ":" << line << ":\n"
<< " \033[0m ULP distance in " << boost::core::demangle(typeid(Real).name())<< " precision is " << dist
<< " \033[0m ULP distance in " << detail::demangle(typeid(Real).name())<< " precision is " << dist
<< ", which exceeds " << ulps;
if (ulps > 0)
{
Expand Down Expand Up @@ -161,7 +167,7 @@ bool check_le(Real lesser, Real greater, std::string const & filename, std::stri
std::ios_base::fmtflags f( std::cerr.flags() );
std::cerr << std::setprecision(3);
std::cerr << "\033[0;31mError at " << filename << ":" << function << ":" << line << ":\n"
<< " \033[0m Condition " << lesser << " \u2264 " << greater << " is violated in " << boost::core::demangle(typeid(Real).name()) << " precision.\n";
<< " \033[0m Condition " << lesser << " \u2264 " << greater << " is violated in " << detail::demangle(typeid(Real).name()) << " precision.\n";
std::cerr << std::setprecision(std::numeric_limits<Real>::max_digits10) << std::showpos
<< " \"Lesser\" : " << std::defaultfloat << std::fixed << lesser << " = " << std::scientific << lesser << std::hexfloat << " = " << lesser << "\n"
<< " \"Greater\": " << std::defaultfloat << std::fixed << greater << " = " << std::scientific << greater << std::hexfloat << " = " << greater << "\n"
Expand Down Expand Up @@ -214,7 +220,7 @@ bool check_conditioned_error(Real abscissa, PreciseReal expected1, PreciseReal e
std::cerr << "\033[0;31mError at " << filename << ":" << function << ":" << line << ":\n";
std::cerr << std::setprecision(std::numeric_limits<Real>::max_digits10) << std::showpos;
std::cerr << "\033[0m Error at abscissa " << std::defaultfloat << std::fixed << abscissa << " = " << std::hexfloat << abscissa << "\n";
std::cerr << " Given that the expected value is zero, the computed value in " << boost::core::demangle(typeid(Real).name()) << " precision must satisfy |f(x)| <= " << tol << ".\n";
std::cerr << " Given that the expected value is zero, the computed value in " << detail::demangle(typeid(Real).name()) << " precision must satisfy |f(x)| <= " << tol << ".\n";
std::cerr << " But the computed value is " << std::defaultfloat << std::fixed << computed << std::hexfloat << " = " << computed << "\n";
std::cerr.flags(f);
++detail::global_error_count;
Expand All @@ -240,7 +246,7 @@ bool check_conditioned_error(Real abscissa, PreciseReal expected1, PreciseReal e
std::cerr << "\033[0;31mError at " << filename << ":" << function << ":" << line << "\n";
std::cerr << std::setprecision(std::numeric_limits<Real>::max_digits10);
std::cerr << "\033[0m The relative error at abscissa x = " << std::defaultfloat << std::fixed << abscissa << " = " << std::hexfloat << abscissa
<< " in " << boost::core::demangle(typeid(Real).name()) << " precision is " << std::scientific << relative_error << "\n"
<< " in " << detail::demangle(typeid(Real).name()) << " precision is " << std::scientific << relative_error << "\n"
<< " This exceeds the tolerance " << tol << "\n"
<< std::showpos
<< " Expected: " << std::defaultfloat << std::fixed << expected << " = " << std::scientific << expected << std::hexfloat << " = " << expected << "\n"
Expand Down Expand Up @@ -288,7 +294,7 @@ bool check_absolute_error(PreciseReal expected1, Real computed, Real acceptable_
std::cerr << std::setprecision(3);
std::cerr << "\033[0;31mError at " << filename << ":" << function << ":" << line << "\n";
std::cerr << std::setprecision(std::numeric_limits<Real>::max_digits10);
std::cerr << "\033[0m The absolute error in " << boost::core::demangle(typeid(Real).name()) << " precision is " << std::scientific << error << "\n"
std::cerr << "\033[0m The absolute error in " << detail::demangle(typeid(Real).name()) << " precision is " << std::scientific << error << "\n"
<< " This exceeds the acceptable error " << acceptable_error << "\n"
<< std::showpos
<< " Expected: " << std::defaultfloat << std::fixed << expected << " = " << std::scientific << expected << std::hexfloat << " = " << expected << "\n"
Expand Down
Loading

0 comments on commit 8937001

Please sign in to comment.