diff options
Diffstat (limited to 'aberth/host_aberth.cpp')
-rw-r--r-- | aberth/host_aberth.cpp | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/aberth/host_aberth.cpp b/aberth/host_aberth.cpp new file mode 100644 index 0000000..c567f30 --- /dev/null +++ b/aberth/host_aberth.cpp @@ -0,0 +1,96 @@ +#include <random> +#include <cmath> +#include "host_aberth.h" + +using namespace std; + + +struct AberthState { + // boundPoly is 's' in the stop condition formulated at p.189-190 of + // https://link.springer.com/article/10.1007%2FBF02207694 + + const Poly &poly; + Poly deriv; + Poly boundPoly; + AApprox approx; + double radius; + + AberthState(const Poly &poly); + void regenerate(); + bool step(); + void iterate(); +}; + +static thread_local minstd_rand randgenerator = minstd_rand(random_device()()); + +AberthState::AberthState(const Poly &poly) + : poly(poly), deriv(derivative(poly)), radius(maxRootNorm(poly)) { + + regenerate(); + for (int i = 0; i <= N; i++) { + boundPoly[i] = abs(poly[i]) * (4 * i + 1); + } +} + +void AberthState::regenerate() { + auto genCoord = [this]() { + return uniform_real_distribution<double>(-radius, radius)(randgenerator); + }; + + for (int i = 0; i < N; i++) { + approx[i] = Com(genCoord(), genCoord()); + } +} + +// Lagrange-style step where the new elements are computed in parallel from the previous values +bool AberthState::step() { + array<Com, N * N> pairs; + for (int i = 0; i < N - 1; i++) { + for (int j = i + 1; j < N; j++) { + pairs[N * i + j] = 1.0 / (approx[i] - approx[j]); + } + } + + bool allConverged = true; + + AApprox newapprox; + AApprox offsets; + for (int i = 0; i < N; i++) { + Com pval = eval(poly, approx[i]); + Com derivval = eval(deriv, poly.size() - 1, approx[i]); + Com quo = pval / derivval; + Com sum = 0; + for (int j = 0; j < i; j++) sum -= pairs[N * j + i]; + for (int j = i + 1; j < N; j++) sum += pairs[N * i + j]; + offsets[i] = quo / (1.0 - quo * sum); + + newapprox[i] = approx[i] - offsets[i]; + + double sval = eval(boundPoly, abs(newapprox[i])); + if (abs(pval) > 1e-5 * sval) allConverged = false; + } + + approx = newapprox; + + return allConverged; +} + +void AberthState::iterate() { + int tries = 1, stepIdx = 1; + + while (!step()) { + stepIdx++; + + if (stepIdx > tries * 100) { + regenerate(); + stepIdx = 0; + tries++; + } + } +} + +AApprox aberth(const Poly &poly) { + AberthState state(poly); + state.iterate(); + return state.approx; +} |