aboutsummaryrefslogtreecommitdiff
path: root/aberth/host_aberth.cpp
blob: fcffebfc245bd99d9fa038b59b498d62f76c7c1d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <random>
#include <cmath>
#include "host_aberth.h"

using namespace std;


// Working state of one Aberth-Ehrlich run that approximates all N roots of `poly`.
//
// boundPoly is 's' in the stop condition formulated at p.189-190 of
// https://link.springer.com/article/10.1007%2FBF02207694
// Its coefficients are |a_i| * (4i + 1); they are filled in by the constructor.
//
// NOTE: `poly` is held by reference. An AberthState must therefore not outlive
// the polynomial it was constructed from.
struct AberthState {
	const Poly &poly;    // polynomial whose roots are sought (not owned)
	Poly deriv;          // derivative(poly); used for the Newton quotient p/p'
	Poly boundPoly;      // s(x) = sum_i |a_i| (4i + 1) x^i, used by the stop test
	AApprox approx;      // current approximations of the N roots
	double radius;       // maxRootNorm(poly); half-width of the square random starts are drawn from

	// explicit: constructing a solver from a Poly should never happen implicitly.
	explicit AberthState(const Poly &poly);
	void regenerate();   // redraw all N approximations at random
	bool step();         // one iteration; true once every approximation passes the stop test
	void iterate();      // call step() repeatedly, restarting from fresh random points when a try runs too long
};

// Per-thread engine for the random starting points, seeded once from the OS entropy source.
static thread_local minstd_rand randgenerator{random_device{}()};

// Binds the polynomial, precomputes its derivative and the root-magnitude radius,
// draws the initial random approximations, and builds the stop-test polynomial
// s(x) = sum_i |a_i| (4i + 1) x^i.
AberthState::AberthState(const Poly &poly)
        : poly(poly), deriv(derivative(poly)), radius(maxRootNorm(poly)) {
    // The starting points depend only on `radius`, which is already set.
    regenerate();

    // Coefficient k of s is |a_k| weighted by (4k + 1), for k = 0..N.
    for (int k = 0; k <= N; ++k) {
        const int weight = 4 * k + 1;
        boundPoly[k] = abs(poly[k]) * weight;
    }
}

// Replaces all N approximations with fresh random points. Each point is drawn
// uniformly (per coordinate) from the square [-radius, radius] x [-radius, radius]
// of the complex plane.
void AberthState::regenerate() {
    const double halfWidth = radius;
    for (int k = 0; k < N; ++k) {
        approx[k] = Com(uniform_real_distribution<double>(-halfWidth, halfWidth)(randgenerator),
                        uniform_real_distribution<double>(-halfWidth, halfWidth)(randgenerator));
    }
}

// Jacobi-style step where the new elements are computed in parallel from the previous values
bool AberthState::step() {
    array<Com, N * N> pairs;
    for (int i = 0; i < N - 1; i++) {
        for (int j = i + 1; j < N; j++) {
            pairs[N * i + j] = 1.0 / (approx[i] - approx[j]);
        }
    }

    bool allConverged = true;

    AApprox newapprox;
    AApprox offsets;
    for (int i = 0; i < N; i++) {
        Com pval = eval(poly, approx[i]);
        Com derivval = eval(deriv, poly.size() - 1, approx[i]);
        Com quo = pval / derivval;
        Com sum = 0;
        for (int j = 0; j < i; j++) sum -= pairs[N * j + i];
        for (int j = i + 1; j < N; j++) sum += pairs[N * i + j];
        offsets[i] = quo / (1.0 - quo * sum);

        newapprox[i] = approx[i] - offsets[i];

        double sval = eval(boundPoly, abs(newapprox[i]));
        if (abs(pval) > 1e-5 * sval) allConverged = false;
    }

    approx = newapprox;

    return allConverged;
}

void AberthState::iterate() {
    int tries = 1, stepIdx = 1;

    while (!step()) {
        stepIdx++;

        if (stepIdx > tries * 100) {
            regenerate();
            stepIdx = 0;
            tries++;
        }
    }
}

// Approximates all N roots of `poly` with the Aberth-Ehrlich method, starting
// from random initial guesses, and returns the final approximations.
AApprox aberth(const Poly &poly) {
	AberthState solver{poly};
	solver.iterate();
	AApprox roots = solver.approx;
	return roots;
}