Skip to content

Commit

Permalink
Cubic roots
Browse files Browse the repository at this point in the history
  • Loading branch information
NAThompson committed Oct 23, 2021
1 parent b0d1e4f commit c692c20
Show file tree
Hide file tree
Showing 3 changed files with 298 additions and 0 deletions.
130 changes: 130 additions & 0 deletions include/boost/math/tools/cubic_roots.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// (C) Copyright Nick Thompson 2019.
// Use, modification and distribution are subject to the
// Boost Software License, Version 1.0. (See accompanying file
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
#ifndef BOOST_MATH_TOOLS_CUBIC_ROOTS_HPP
#define BOOST_MATH_TOOLS_CUBIC_ROOTS_HPP
#include <algorithm>
#include <array>
#include <cmath>
#include <limits>
#include <boost/math/tools/roots.hpp>

namespace boost::math::tools {

namespace detail {
// Sign of val as an int: +1 when val is greater than zero, -1 when it is less
// than zero, and 0 when neither comparison holds (zero itself, or NaN).
template <typename Real> int sgn(Real val) {
    const Real zero(0);
    if (zero < val) {
        return 1;
    }
    if (val < zero) {
        return -1;
    }
    return 0;
}
}
// Solves ax³ + bx² + cx + d = 0.
// Only returns the real roots, as types get weird for real coefficients and complex roots.
// Follows Numerical Recipes, Chapter 5, section 6.
//
// Parameters:
//   a, b, c, d - the coefficients, highest degree first. Any of them may be zero;
//                the degenerate (quadratic, linear, constant) cases are handled explicitly.
// Returns:
//   An array of three values; every slot that does not hold a real root is a quiet NaN.
//   - a == b == c == 0: all NaN if d != 0 (no solutions), {0, 0, 0} if d == 0.
//   - a == b == 0, c != 0: {-d/c, NaN, NaN}.
//   - a == 0, b != 0: {x0, x1, NaN}, with x0 and x1 exactly as quadratic_roots returns them.
//   - a != 0: the real roots in ascending order, followed by any NaN placeholders.
template<typename Real>
std::array<Real, 3> cubic_roots(Real a, Real b, Real c, Real d) {
    using std::sqrt;
    using std::acos;
    using std::cos;
    // sin must be brought in like the others: without this, the unqualified call below
    // can bind to the C library's ::sin(double) for built-in types (losing long double
    // precision), or fail to compile if no global sin is visible.
    using std::sin;
    using std::cbrt;
    using std::abs;
    using std::fma;
    using std::isnan;

    const Real nan = std::numeric_limits<Real>::quiet_NaN();
    std::array<Real, 3> roots = {nan, nan, nan};

    // Ordering used for the final sort: ascending, with every NaN after every number.
    // Plain operator< is not a strict weak ordering once a NaN is present,
    // which std::sort requires.
    auto nan_last = [](Real const& lhs, Real const& rhs) {
        using std::isnan;
        if (isnan(lhs)) {
            return false;
        }
        if (isnan(rhs)) {
            return true;
        }
        return lhs < rhs;
    };

    // One Newton step on the original (non-monic) cubic, evaluated with Horner's method.
    // Here I'll take John Gustafson's opinion that the fma is a *distinct* operation from a*x + b:
    // Make sure to compile these fmas into a single instruction!
    auto polish = [&](Real& x) {
        Real f = fma(a, x, b);
        f = fma(f, x, c);
        f = fma(f, x, d);
        Real df = fma(3*a, x, 2*b);
        df = fma(df, x, c);
        if (df != 0) {
            // No standard library feature for fused-divide add!
            x -= f/df;
        }
    };

    if (a == 0) {
        // bx^2 + cx + d = 0:
        if (b == 0) {
            // cx + d = 0:
            if (c == 0) {
                if (d != 0) {
                    // No solutions:
                    return roots;
                }
                roots[0] = 0;
                roots[1] = 0;
                roots[2] = 0;
                return roots;
            }
            roots[0] = -d/c;
            return roots;
        }
        auto [x0, x1] = quadratic_roots(b, c, d);
        roots[0] = x0;
        roots[1] = x1;
        return roots;
    }
    if (d == 0) {
        // x*(ax^2 + bx + c) = 0: zero is a root, the rest come from the quadratic factor.
        auto [x0, x1] = quadratic_roots(a, b, c);
        roots[0] = x0;
        roots[1] = x1;
        roots[2] = 0;
        // NOTE(review): quadratic_roots presumably reports "no real roots" as NaN - confirm.
        // The NaN-aware ordering then keeps the real root(s) in front, as in the other branches.
        std::sort(roots.begin(), roots.end(), nan_last);
        return roots;
    }

    // Normalize to the monic cubic x^3 + p x^2 + q x + r.
    Real p = b/a;
    Real q = c/a;
    Real r = d/a;
    Real Q = (p*p - 3*q)/9;
    Real R = (2*p*p*p - 9*p*q + 27*r)/54;
    if (R*R < Q*Q*Q) {
        // Three distinct real roots: trigonometric form.
        Real rtQ = sqrt(Q);
        Real theta = acos(R/(Q*rtQ))/3;
        Real st = sin(theta);
        Real ct = cos(theta);
        const Real rt3 = sqrt(Real(3));
        roots[0] = -2*rtQ*ct - p/3;
        roots[1] = -rtQ*(-ct + rt3*st) - p/3;
        roots[2] = rtQ*(ct + rt3*st) - p/3;
        // This formula is not super accurate.
        // Do a cleanup iteration.
        for (auto & root : roots) {
            polish(root);
        }
        std::sort(roots.begin(), roots.end(), nan_last);
        return roots;
    }
    // In Numerical Recipes, Chapter 5, Section 6, it is claimed that we only have one real root
    // if R^2 >= Q^3. But this isn't true; we can even see this from equation 5.6.18.
    // The condition for having three real roots is that A = B.
    // It *is* the case that if we're in this branch, and we have 3 real roots, two are a double root.
    // Take (x+1)^2(x-2) = x^3 - 3x - 2 as an example. This clearly has a double root at x = -1,
    // and it gets sent into this branch.
    Real arg = R*R - Q*Q*Q;
    Real A = -detail::sgn(R)*cbrt(abs(R) + sqrt(arg));
    Real B = 0;
    if (A != 0) {
        B = Q/A;
    }
    roots[0] = A + B - p/3;
    // Yes, we're comparing floats for equality:
    // Any perturbation pushes the roots into the complex plane; out of the bailiwick of this routine.
    if (A == B || arg == 0) {
        roots[1] = -A - p/3;
        roots[2] = -A - p/3;
    }
    // Slots still holding NaN stay NaN through the Newton step.
    for (auto & root : roots) {
        polish(root);
    }
    std::sort(roots.begin(), roots.end(), nan_last);
    return roots;
}

}
#endif
41 changes: 41 additions & 0 deletions reporting/performance/cubic_roots_performance.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// (C) Copyright Nick Thompson 2021.
// Use, modification and distribution are subject to the
// Boost Software License, Version 1.0. (See accompanying file
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

#include <random>
#include <array>
#include <vector>
#include <iostream>
#include <benchmark/benchmark.h>
#include <boost/math/tools/cubic_roots.hpp>

using boost::math::tools::cubic_roots;

// Benchmarks cubic_roots on one cubic whose four coefficients are drawn from
// a uniform distribution on [-10, 10) with a fixed seed.
// After the timing loop, the polynomial and the seed that produced it are printed to stdout.
template<class Real>
void CubicRoots(benchmark::State& state)
{
    // The seed is fixed so the benchmarked polynomial is reproducible.
    // To time a different polynomial on each run, seed from std::random_device{}() instead.
    const uint32_t seed = 416683252;
    std::mt19937_64 mt(seed);
    std::uniform_real_distribution<Real> unif(-10, 10);

    Real a = unif(mt);
    Real b = unif(mt);
    Real c = unif(mt);
    Real d = unif(mt);
    for (auto _ : state)
    {
        auto roots = cubic_roots(a,b,c,d);
        // Keep the optimizer from discarding the computation being timed.
        benchmark::DoNotOptimize(roots[0]);
    }
    std::cout << "Just solved " << a << "x^3 + " << b << "x^2 + " << c << "x + " << d << "\n";
    std::cout << "This was generated by seed " << seed << "\n";
}

// Benchmark registration. Only the double instantiation is enabled;
// the float and long double registrations are left commented out.
//BENCHMARK_TEMPLATE(CubicRoots, float);
BENCHMARK_TEMPLATE(CubicRoots, double);
//BENCHMARK_TEMPLATE(CubicRoots, long double);

// NOTE(review): BENCHMARK_MAIN comes from <benchmark/benchmark.h> and presumably
// expands to the program's main() - confirm against the Google Benchmark docs.
BENCHMARK_MAIN();
127 changes: 127 additions & 0 deletions test/cubic_roots_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright Nick Thompson, 2021
* Use, modification and distribution are subject to the
* Boost Software License, Version 1.0. (See accompanying file
* LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
*/

#include "math_unit_test.hpp"
#include <random>
#include <boost/math/tools/cubic_roots.hpp>
#ifdef BOOST_HAS_FLOAT128
#include <boost/multiprecision/float128.hpp>
using boost::multiprecision::float128;
#endif

using boost::math::tools::cubic_roots;
using std::cbrt;

// Checks cubic_roots for the type Real on:
//   - the all-zero polynomial and x^3 = 0,
//   - pure cubics x^3 + d = 0, which have a single real root,
//   - a cubic with three distinct integer roots and one with a double root,
//   - ten random monic cubics built from known roots (fixed seed 12345).
template<class Real>
void test_zero_coefficients()
{
    const Real zero(0);
    const Real one(1);

    // 0 = 0:
    auto roots = cubic_roots(zero, zero, zero, zero);
    CHECK_EQUAL(roots[0], Real(0));
    CHECK_EQUAL(roots[1], Real(0));
    CHECK_EQUAL(roots[2], Real(0));

    // x^3 = 0:
    roots = cubic_roots(one, zero, zero, zero);
    CHECK_EQUAL(roots[0], Real(0));
    CHECK_EQUAL(roots[1], Real(0));
    CHECK_EQUAL(roots[2], Real(0));

    // x^3 + 1 = 0:
    roots = cubic_roots(one, zero, zero, Real(1));
    CHECK_EQUAL(roots[0], Real(-1));
    CHECK_NAN(roots[1]);
    CHECK_NAN(roots[2]);

    // x^3 - 1 = 0:
    roots = cubic_roots(one, zero, zero, Real(-1));
    CHECK_EQUAL(roots[0], Real(1));
    CHECK_NAN(roots[1]);
    CHECK_NAN(roots[2]);

    // x^3 - 2 = 0
    roots = cubic_roots(one, zero, zero, Real(-2));
    // Let's go for equality here!
    CHECK_EQUAL(roots[0], cbrt(Real(2)));
    CHECK_NAN(roots[1]);
    CHECK_NAN(roots[2]);

    // x^3 - 8 = 0
    roots = cubic_roots(one, zero, zero, Real(-8));
    CHECK_EQUAL(roots[0], Real(2));
    CHECK_NAN(roots[1]);
    CHECK_NAN(roots[2]);

    // Three distinct roots:
    // (x-1)(x-2)(x-3) = x^3 - 6x^2 + 11x - 6
    roots = cubic_roots(Real(1), Real(-6), Real(11), Real(-6));
    CHECK_ULP_CLOSE(roots[0], Real(1), 2);
    CHECK_ULP_CLOSE(roots[1], Real(2), 2);
    CHECK_ULP_CLOSE(roots[2], Real(3), 2);

    // Double root:
    // (x+1)^2(x-2) = x^3 - 3x - 2:
    // Note: This test is unstable with respect to perturbations!
    roots = cubic_roots(Real(1), Real(0), Real(-3), Real(-2));
    CHECK_ULP_CLOSE(-1, roots[0], 2);
    CHECK_ULP_CLOSE(-1, roots[1], 2);
    CHECK_ULP_CLOSE(2, roots[2], 2);

    // Random monic cubics with prescribed roots drawn from [-2, 2).
    std::uniform_real_distribution<Real> dis(-2,2);
    std::mt19937 gen(12345);
    std::array<Real, 3> expected;
    const int trials = 10;
    for (int trial = 0; trial < trials; ++trial) {
        for (auto it = expected.begin(); it != expected.end(); ++it) {
            *it = static_cast<Real>(dis(gen));
        }
        std::sort(expected.begin(), expected.end());

        // Mathematica:
        // Expand[(x - r0)*(x - r1)*(x - r2)]
        // - r0 r1 r2 + (r0 r1 + r0 r2 + r1 r2) x
        // - (r0 + r1 + r2) x^2 + x^3
        const Real lead = 1;
        const Real quad = -(expected[0] + expected[1] + expected[2]);
        const Real lin = expected[0]*expected[1] + expected[0]*expected[2] + expected[1]*expected[2];
        const Real constant = -expected[0]*expected[1]*expected[2];

        const auto computed = cubic_roots(lead, quad, lin, constant);
        // I could check the condition number here, but this is fine right?
        // On failure of the smallest root, dump the polynomial and both root sets.
        if(!CHECK_ULP_CLOSE(expected[0], computed[0], 3)) {
            std::cerr << " Polynomial x^3 + " << quad << "x^2 + " << lin << "x + " << constant << " has roots {";
            std::cerr << expected[0] << ", " << expected[1] << ", " << expected[2] << "}, but the computed roots are {";
            std::cerr << computed[0] << ", " << computed[1] << ", " << computed[2] << "}\n";
        }
        CHECK_ULP_CLOSE(expected[1], computed[1], 3);
        CHECK_ULP_CLOSE(expected[2], computed[2], 3);
    }
}


// Test driver: runs the cubic_roots checks for each enabled floating-point type
// and returns whatever boost::math::test::report_errors() reports.
int main()
{
    test_zero_coefficients<float>();
    test_zero_coefficients<double>();
#ifndef BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS
    // Skipped when the platform lacks long double math functions.
    test_zero_coefficients<long double>();
#endif

#ifdef BOOST_HAS_FLOAT128
    // Disabled: for some reason, the quadmath is way less accurate than the float/double/long double.
    //test_zero_coefficients<float128>();
#endif

    const auto failures = boost::math::test::report_errors();
    return failures;
}

0 comments on commit c692c20

Please sign in to comment.