// quartic_roots.hpp (5.1 KB)
  1. // (C) Copyright Nick Thompson 2021.
  2. // Use, modification and distribution are subject to the
  3. // Boost Software License, Version 1.0. (See accompanying file
  4. // LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  5. #ifndef BOOST_MATH_TOOLS_QUARTIC_ROOTS_HPP
  6. #define BOOST_MATH_TOOLS_QUARTIC_ROOTS_HPP
  7. #include <array>
  8. #include <cmath>
  9. #include <boost/math/tools/cubic_roots.hpp>
  10. namespace boost::math::tools {
  namespace detail {

  // Ordering predicate used when sorting root arrays: ordinary values compare
  // with operator<, while every NaN is treated as greater than any non-NaN
  // value so that NaNs end up at the back of a sorted range.
  // Returns true when r1 must be placed before r2.
  template<typename Real>
  bool comparator(Real r1, Real r2) {
      using std::isnan;
      // A NaN on the left never precedes anything (including another NaN).
      if (isnan(r1)) {
          return false;
      }
      // A non-NaN on the left always precedes a NaN on the right.
      return isnan(r2) || r1 < r2;
  }

  // Refines each candidate root of a*x^4 + b*x^3 + c*x^2 + d*x + e with a single
  // Halley step, then sorts the array ascending with NaN entries last.
  //
  // a, b, c, d, e: coefficients of the quartic, highest degree first.
  // roots: candidate roots; modified in place (polished and sorted).
  // Returns a copy of the polished, sorted array.
  template<typename Real>
  std::array<Real, 4> polish_and_sort(Real a, Real b, Real c, Real d, Real e, std::array<Real, 4>& roots) {
      using std::fma;
      using std::abs;
      const Real tiny = (std::numeric_limits<Real>::min)();
      for (Real& root : roots) {
          // Horner evaluation of the polynomial itself ...
          Real value = fma(a, root, b);
          value = fma(value, root, c);
          value = fma(value, root, d);
          value = fma(value, root, e);
          // ... of its first derivative 4a x^3 + 3b x^2 + 2c x + d ...
          Real slope = fma(4*a, root, 3*b);
          slope = fma(slope, root, 2*c);
          slope = fma(slope, root, d);
          // ... and of its second derivative 12a x^2 + 6b x + 2c.
          Real curvature = fma(12*a, root, 6*b);
          curvature = fma(curvature, root, 2*c);

          // Halley update: x <- x - 2 f f' / (2 f'^2 - f f'').
          const Real denom = 2*slope*slope - value*curvature;
          // Leave the root untouched when the denominator is too small to divide
          // by. The negated '>' also skips a NaN denominator, which is what
          // happens for NaN placeholder roots.
          if (!(abs(denom) > tiny)) {
              continue;
          }
          root -= 2*value*slope/denom;
      }
      std::sort(roots.begin(), roots.end(),
                [](Real lhs, Real rhs) { return detail::comparator<Real>(lhs, rhs); });
      return roots;
  }

  }
  45. // Solves ax^4 + bx^3 + cx^2 + dx + e = 0.
  46. // Only returns the real roots, as these are the only roots of interest in ray intersection problems.
  47. // Follows Graphics Gems V: https://github.com/erich666/GraphicsGems/blob/master/gems/Roots3And4.c
  48. template<typename Real>
  49. std::array<Real, 4> quartic_roots(Real a, Real b, Real c, Real d, Real e) {
  50. using std::abs;
  51. using std::sqrt;
  52. auto nan = std::numeric_limits<Real>::quiet_NaN();
  53. std::array<Real, 4> roots{nan, nan, nan, nan};
  54. if (abs(a) <= (std::numeric_limits<Real>::min)()) {
  55. auto cbrts = cubic_roots(b, c, d, e);
  56. roots[0] = cbrts[0];
  57. roots[1] = cbrts[1];
  58. roots[2] = cbrts[2];
  59. if (b == 0 && c == 0 && d == 0 && e == 0) {
  60. roots[3] = 0;
  61. }
  62. return detail::polish_and_sort(a, b, c, d, e, roots);
  63. }
  64. if (abs(e) <= (std::numeric_limits<Real>::min)()) {
  65. auto v = cubic_roots(a, b, c, d);
  66. roots[0] = v[0];
  67. roots[1] = v[1];
  68. roots[2] = v[2];
  69. roots[3] = 0;
  70. return detail::polish_and_sort(a, b, c, d, e, roots);
  71. }
  72. // Now solve x^4 + Ax^3 + Bx^2 + Cx + D = 0.
  73. Real A = b/a;
  74. Real B = c/a;
  75. Real C = d/a;
  76. Real D = e/a;
  77. Real Asq = A*A;
  78. // Let x = y - A/4:
  79. // Mathematica: Expand[(y - A/4)^4 + A*(y - A/4)^3 + B*(y - A/4)^2 + C*(y - A/4) + D]
  80. // We now solve the depressed quartic y^4 + py^2 + qy + r = 0.
  81. Real p = B - 3*Asq/8;
  82. Real q = C - A*B/2 + Asq*A/8;
  83. Real r = D - A*C/4 + Asq*B/16 - 3*Asq*Asq/256;
  84. if (abs(r) <= (std::numeric_limits<Real>::min)()) {
  85. auto [r1, r2, r3] = cubic_roots(Real(1), Real(0), p, q);
  86. r1 -= A/4;
  87. r2 -= A/4;
  88. r3 -= A/4;
  89. roots[0] = r1;
  90. roots[1] = r2;
  91. roots[2] = r3;
  92. roots[3] = -A/4;
  93. return detail::polish_and_sort(a, b, c, d, e, roots);
  94. }
  95. // Biquadratic case:
  96. if (abs(q) <= (std::numeric_limits<Real>::min)()) {
  97. auto [r1, r2] = quadratic_roots(Real(1), p, r);
  98. if (r1 >= 0) {
  99. Real rtr = sqrt(r1);
  100. roots[0] = rtr - A/4;
  101. roots[1] = -rtr - A/4;
  102. }
  103. if (r2 >= 0) {
  104. Real rtr = sqrt(r2);
  105. roots[2] = rtr - A/4;
  106. roots[3] = -rtr - A/4;
  107. }
  108. return detail::polish_and_sort(a, b, c, d, e, roots);
  109. }
  110. // Now split the depressed quartic into two quadratics:
  111. // y^4 + py^2 + qy + r = (y^2 + sy + u)(y^2 - sy + v) = y^4 + (v+u-s^2)y^2 + s(v - u)y + uv
  112. // So p = v+u-s^2, q = s(v - u), r = uv.
  113. // Then (v+u)^2 - (v-u)^2 = 4uv = 4r = (p+s^2)^2 - q^2/s^2.
  114. // Multiply through by s^2 to get s^2(p+s^2)^2 - q^2 - 4rs^2 = 0, which is a cubic in s^2.
  115. // Then we let z = s^2, to get
  116. // z^3 + 2pz^2 + (p^2 - 4r)z - q^2 = 0.
  117. auto z_roots = cubic_roots(Real(1), 2*p, p*p - 4*r, -q*q);
  118. // z = s^2, so s = sqrt(z).
  119. // Hence we require a root > 0, and for the sake of sanity we should take the largest one:
  120. Real largest_root = std::numeric_limits<Real>::lowest();
  121. for (auto z : z_roots) {
  122. if (z > largest_root) {
  123. largest_root = z;
  124. }
  125. }
  126. // No real roots:
  127. if (largest_root <= 0) {
  128. return roots;
  129. }
  130. Real s = sqrt(largest_root);
  131. // s is nonzero, because we took care of the biquadratic case.
  132. Real v = (p + largest_root + q/s)/2;
  133. Real u = v - q/s;
  134. // Now solve y^2 + sy + u = 0:
  135. auto [root0, root1] = quadratic_roots(Real(1), s, u);
  136. // Now solve y^2 - sy + v = 0:
  137. auto [root2, root3] = quadratic_roots(Real(1), -s, v);
  138. roots[0] = root0;
  139. roots[1] = root1;
  140. roots[2] = root2;
  141. roots[3] = root3;
  142. for (auto& r : roots) {
  143. r -= A/4;
  144. }
  145. return detail::polish_and_sort(a, b, c, d, e, roots);
  146. }
  147. }
  148. #endif