aboutsummaryrefslogtreecommitdiff
path: root/aberth/host_aberth.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'aberth/host_aberth.cpp')
-rw-r--r--aberth/host_aberth.cpp96
1 files changed, 96 insertions, 0 deletions
diff --git a/aberth/host_aberth.cpp b/aberth/host_aberth.cpp
new file mode 100644
index 0000000..c567f30
--- /dev/null
+++ b/aberth/host_aberth.cpp
@@ -0,0 +1,96 @@
+#include <random>
+#include <cmath>
+#include "host_aberth.h"
+
+using namespace std;
+
+
+struct AberthState {
+ // boundPoly is 's' in the stop condition formulated at p.189-190 of
+ // https://link.springer.com/article/10.1007%2FBF02207694
+
+ const Poly &poly;
+ Poly deriv;
+ Poly boundPoly;
+ AApprox approx;
+ double radius;
+
+ AberthState(const Poly &poly);
+ void regenerate();
+ bool step();
+ void iterate();
+};
+
+static thread_local minstd_rand randgenerator = minstd_rand(random_device()());
+
+AberthState::AberthState(const Poly &poly)
+ : poly(poly), deriv(derivative(poly)), radius(maxRootNorm(poly)) {
+
+ regenerate();
+ for (int i = 0; i <= N; i++) {
+ boundPoly[i] = abs(poly[i]) * (4 * i + 1);
+ }
+}
+
+void AberthState::regenerate() {
+ auto genCoord = [this]() {
+ return uniform_real_distribution<double>(-radius, radius)(randgenerator);
+ };
+
+ for (int i = 0; i < N; i++) {
+ approx[i] = Com(genCoord(), genCoord());
+ }
+}
+
+// Lagrange-style step where the new elements are computed in parallel from the previous values
+bool AberthState::step() {
+ array<Com, N * N> pairs;
+ for (int i = 0; i < N - 1; i++) {
+ for (int j = i + 1; j < N; j++) {
+ pairs[N * i + j] = 1.0 / (approx[i] - approx[j]);
+ }
+ }
+
+ bool allConverged = true;
+
+ AApprox newapprox;
+ AApprox offsets;
+ for (int i = 0; i < N; i++) {
+ Com pval = eval(poly, approx[i]);
+ Com derivval = eval(deriv, poly.size() - 1, approx[i]);
+ Com quo = pval / derivval;
+ Com sum = 0;
+ for (int j = 0; j < i; j++) sum -= pairs[N * j + i];
+ for (int j = i + 1; j < N; j++) sum += pairs[N * i + j];
+ offsets[i] = quo / (1.0 - quo * sum);
+
+ newapprox[i] = approx[i] - offsets[i];
+
+ double sval = eval(boundPoly, abs(newapprox[i]));
+ if (abs(pval) > 1e-5 * sval) allConverged = false;
+ }
+
+ approx = newapprox;
+
+ return allConverged;
+}
+
+void AberthState::iterate() {
+ int tries = 1, stepIdx = 1;
+
+ while (!step()) {
+ stepIdx++;
+
+ if (stepIdx > tries * 100) {
+ regenerate();
+ stepIdx = 0;
+ tries++;
+ }
+ }
+}
+
+AApprox aberth(const Poly &poly) {
+ AberthState state(poly);
+ state.iterate();
+ return state.approx;
+}