diff options
Diffstat (limited to 'aberth/main.cpp')
-rw-r--r-- | aberth/main.cpp | 224 |
1 files changed, 6 insertions, 218 deletions
diff --git a/aberth/main.cpp b/aberth/main.cpp index 1cdea3a..152640f 100644 --- a/aberth/main.cpp +++ b/aberth/main.cpp @@ -1,232 +1,20 @@ #include <iostream> #include <fstream> -#include <sstream> -#include <vector> -#include <array> #include <string> -#include <complex> -#include <random> -#include <utility> -#include <tuple> +#include <vector> #include <algorithm> -#include <thread> -#include <mutex> -#include <type_traits> +#include <tuple> #include <cstdint> +#include <cstring> #include <cassert> #include "../lodepng.h" +#include "compute_host.h" #include "defs.h" #include "kernel.h" -#include "util.h" using namespace std; -template <typename T> -static T eval(const Poly &p, int nterms, T pt) { - T value = p[nterms - 1]; - for (int i = nterms - 2; i >= 0; i--) { - value = pt * value + (double)p[i]; - } - return value; -} - -template <typename T> -static T eval(const Poly &p, T pt) { - return eval(p, p.size(), pt); -} - -static Poly derivative(const Poly &p) { - Poly res; - for (int i = res.size() - 2; i >= 0; i--) { - res[i] = (i+1) * p[i+1]; - } - return res; -} - -static double maxRootNorm(const Poly &poly) { - // Cauchy's bound: https://en.wikipedia.org/wiki/Geometrical_properties_of_polynomial_roots#Lagrange's_and_Cauchy's_bounds - - double value = 0; - double last = (double)poly.back(); - for (int i = 0; i < (int)poly.size() - 1; i++) { - value = max(value, abs(poly[i] / last)); - } - return 1 + value; -} - -static thread_local minstd_rand randgenerator = minstd_rand(random_device()()); - -struct AberthState { - const Poly &poly; - Poly deriv; - Poly boundPoly; - AApprox approx; - double radius; - - void regenerate() { - auto genCoord = [this]() { - return uniform_real_distribution<double>(-radius, radius)(randgenerator); - }; - for (int i = 0; i < N; i++) { - approx[i] = Com(genCoord(), genCoord()); - } - } - - // boundPoly is 's' in the stop condition formulated at p.189-190 of - // https://link.springer.com/article/10.1007%2FBF02207694 - - AberthState(const Poly &poly) - : poly(poly), deriv(derivative(poly)), radius(maxRootNorm(poly)) { - - regenerate(); - for (int i = 0; i <= N; i++) { - boundPoly[i] = abs(poly[i]) * (4 * i + 1); - } - } - - // Lagrange-style step where the new elements are computed in parallel from the previous values - bool step() { - array<Com, N * N> pairs; - for (int i = 0; i < N - 1; i++) { - for (int j = i + 1; j < N; j++) { - pairs[N * i + j] = 1.0 / (approx[i] - approx[j]); - } - } - - bool allConverged = true; - - AApprox newapprox; - AApprox offsets; - for (int i = 0; i < N; i++) { - Com pval = eval(poly, approx[i]); - Com derivval = eval(deriv, poly.size() - 1, approx[i]); - Com quo = pval / derivval; - Com sum = 0; - for (int j = 0; j < i; j++) sum -= pairs[N * j + i]; - for (int j = i + 1; j < N; j++) sum += pairs[N * i + j]; - offsets[i] = quo / (1.0 - quo * sum); - - // approx[i] -= offsets[i]; - newapprox[i] = approx[i] - offsets[i]; - - double sval = eval(boundPoly, abs(newapprox[i])); - if (abs(pval) > 1e-5 * sval) allConverged = false; - } - - approx = newapprox; - - return allConverged; - } - - void iterate() { - int tries = 1, stepIdx = 1; - while (!step()) { - stepIdx++; - - if (stepIdx > tries * 100) { - regenerate(); - stepIdx = 0; - tries++; - } - } - } -}; - -static AApprox aberth(const Poly &poly) { - AberthState state(poly); - state.iterate(); - return state.approx; -} - -// Returns whether we just looped around -static bool nextDerbyshire(Poly &poly) { - for (int i = 1; i < (int)poly.size(); i++) { - if (poly[i] == -1) { - poly[i] = 1; - return false; - } - poly[i] = -1; - } - return true; -} - -static Poly derbyshireAtIndex(int index) { - Poly poly; - poly[0] = 1; - for (int i = 1; i <= N; i++) { - poly[i] = index & 1 ? 1 : -1; - index >>= 1; - } - assert(index == 0); - return poly; -} - -struct Job { - Poly init; - int numItems; -}; - -static vector<Job> derbyshireJobs(int targetJobs) { - int njobs = min(1 << N, ceil2(targetJobs)); - int jobsize = (1 << N) / njobs; - - vector<Job> jobs(njobs); - for (int i = 0; i < njobs; i++) { - jobs[i].init = derbyshireAtIndex(i * jobsize); - jobs[i].numItems = jobsize; - } - - return jobs; -} - -static vector<int> computeCounts(int W, int H, Com bottomLeft, Com topRight) { - constexpr const int numThreads = 4; - static_assert(ispow2(numThreads)); - - vector<int> counts(W * H); - mutex countsMutex; - - vector<Job> jobs = derbyshireJobs(numThreads); - assert(jobs.size() == numThreads); - - vector<thread> threads(jobs.size()); - for (int i = 0; i < (int)jobs.size(); i++) { - threads[i] = thread([W, H, &counts, &countsMutex, job = jobs[i], bottomLeft, topRight]() { - auto calcIndex = [](double value, double left, double right, int steps) -> int { - return (value - left) / (right - left) * (steps - 1) + 0.5; - }; - auto calcPos = [W, H, bottomLeft, topRight, &calcIndex](Com z) -> pair<int, int> { - return make_pair( - calcIndex(z.real(), bottomLeft.real(), topRight.real(), W), - calcIndex(z.imag(), bottomLeft.imag(), topRight.imag(), H) - ); - }; - - vector<int> localCounts(W * H); - - Poly poly = job.init; - for (int i = 0; i < job.numItems; i++) { - for (Com z : aberth(poly)) { - int x, y; - tie(x, y) = calcPos(z); - if (0 <= x && x < W && 0 <= y && y < H) { - localCounts[W * y + x]++; - } - } - nextDerbyshire(poly); - } - - lock_guard<mutex> guard(countsMutex); - for (int i = 0; i < W * H; i++) counts[i] += localCounts[i]; - }); - } - - for (thread &th : threads) th.join(); - - return counts; -} - static void writeCounts(int W, int H, const vector<int> &counts, const char *fname) { ofstream f(fname); f << W << ' ' << H << '\n'; @@ -289,12 +77,12 @@ int main(int argc, char **argv) { const Com bottomLeft = Com(-1.5, -1.5); const Com topRight = Com(1.5, 1.5); - // counts = computeCounts(W, H, bottomLeft, topRight); + // counts = computeHost(W, H, bottomLeft, topRight); Kernel().run_chunked(counts, W, H, bottomLeft, topRight, 42, 1 << 14); // Kernel().run_all(counts, W, H, bottomLeft, topRight, 42); writeCounts(W, H, counts, "out.txt"); - } else if (argc == 2) { + } else if (argc == 2 && strcmp(argv[1], "-h") != 0) { tie(W, H, counts) = readCounts(argv[1]); } else { cerr << "Usage: " << argv[0] << " -- compute and draw" << endl; |