about | summary | refs | log | tree | commit | diff
path: root/aberth/main.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'aberth/main.cpp')
-rw-r--r-- aberth/main.cpp | 224
1 files changed, 6 insertions, 218 deletions
diff --git a/aberth/main.cpp b/aberth/main.cpp
index 1cdea3a..152640f 100644
--- a/aberth/main.cpp
+++ b/aberth/main.cpp
@@ -1,232 +1,20 @@
#include <iostream>
#include <fstream>
-#include <sstream>
-#include <vector>
-#include <array>
#include <string>
-#include <complex>
-#include <random>
-#include <utility>
-#include <tuple>
+#include <vector>
#include <algorithm>
-#include <thread>
-#include <mutex>
-#include <type_traits>
+#include <tuple>
#include <cstdint>
+#include <cstring>
#include <cassert>
#include "../lodepng.h"
+#include "compute_host.h"
#include "defs.h"
#include "kernel.h"
-#include "util.h"
using namespace std;
-template <typename T>
-static T eval(const Poly &p, int nterms, T pt) {
- T value = p[nterms - 1];
- for (int i = nterms - 2; i >= 0; i--) {
- value = pt * value + (double)p[i];
- }
- return value;
-}
-
-template <typename T>
-static T eval(const Poly &p, T pt) {
- return eval(p, p.size(), pt);
-}
-
-static Poly derivative(const Poly &p) {
- Poly res;
- for (int i = res.size() - 2; i >= 0; i--) {
- res[i] = (i+1) * p[i+1];
- }
- return res;
-}
-
-static double maxRootNorm(const Poly &poly) {
- // Cauchy's bound: https://en.wikipedia.org/wiki/Geometrical_properties_of_polynomial_roots#Lagrange's_and_Cauchy's_bounds
-
- double value = 0;
- double last = (double)poly.back();
- for (int i = 0; i < (int)poly.size() - 1; i++) {
- value = max(value, abs(poly[i] / last));
- }
- return 1 + value;
-}
-
-static thread_local minstd_rand randgenerator = minstd_rand(random_device()());
-
-struct AberthState {
- const Poly &poly;
- Poly deriv;
- Poly boundPoly;
- AApprox approx;
- double radius;
-
- void regenerate() {
- auto genCoord = [this]() {
- return uniform_real_distribution<double>(-radius, radius)(randgenerator);
- };
- for (int i = 0; i < N; i++) {
- approx[i] = Com(genCoord(), genCoord());
- }
- }
-
- // boundPoly is 's' in the stop condition formulated at p.189-190 of
- // https://link.springer.com/article/10.1007%2FBF02207694
-
- AberthState(const Poly &poly)
- : poly(poly), deriv(derivative(poly)), radius(maxRootNorm(poly)) {
-
- regenerate();
- for (int i = 0; i <= N; i++) {
- boundPoly[i] = abs(poly[i]) * (4 * i + 1);
- }
- }
-
- // Lagrange-style step where the new elements are computed in parallel from the previous values
- bool step() {
- array<Com, N * N> pairs;
- for (int i = 0; i < N - 1; i++) {
- for (int j = i + 1; j < N; j++) {
- pairs[N * i + j] = 1.0 / (approx[i] - approx[j]);
- }
- }
-
- bool allConverged = true;
-
- AApprox newapprox;
- AApprox offsets;
- for (int i = 0; i < N; i++) {
- Com pval = eval(poly, approx[i]);
- Com derivval = eval(deriv, poly.size() - 1, approx[i]);
- Com quo = pval / derivval;
- Com sum = 0;
- for (int j = 0; j < i; j++) sum -= pairs[N * j + i];
- for (int j = i + 1; j < N; j++) sum += pairs[N * i + j];
- offsets[i] = quo / (1.0 - quo * sum);
-
- // approx[i] -= offsets[i];
- newapprox[i] = approx[i] - offsets[i];
-
- double sval = eval(boundPoly, abs(newapprox[i]));
- if (abs(pval) > 1e-5 * sval) allConverged = false;
- }
-
- approx = newapprox;
-
- return allConverged;
- }
-
- void iterate() {
- int tries = 1, stepIdx = 1;
- while (!step()) {
- stepIdx++;
-
- if (stepIdx > tries * 100) {
- regenerate();
- stepIdx = 0;
- tries++;
- }
- }
- }
-};
-
-static AApprox aberth(const Poly &poly) {
- AberthState state(poly);
- state.iterate();
- return state.approx;
-}
-
-// Returns whether we just looped around
-static bool nextDerbyshire(Poly &poly) {
- for (int i = 1; i < (int)poly.size(); i++) {
- if (poly[i] == -1) {
- poly[i] = 1;
- return false;
- }
- poly[i] = -1;
- }
- return true;
-}
-
-static Poly derbyshireAtIndex(int index) {
- Poly poly;
- poly[0] = 1;
- for (int i = 1; i <= N; i++) {
- poly[i] = index & 1 ? 1 : -1;
- index >>= 1;
- }
- assert(index == 0);
- return poly;
-}
-
-struct Job {
- Poly init;
- int numItems;
-};
-
-static vector<Job> derbyshireJobs(int targetJobs) {
- int njobs = min(1 << N, ceil2(targetJobs));
- int jobsize = (1 << N) / njobs;
-
- vector<Job> jobs(njobs);
- for (int i = 0; i < njobs; i++) {
- jobs[i].init = derbyshireAtIndex(i * jobsize);
- jobs[i].numItems = jobsize;
- }
-
- return jobs;
-}
-
-static vector<int> computeCounts(int W, int H, Com bottomLeft, Com topRight) {
- constexpr const int numThreads = 4;
- static_assert(ispow2(numThreads));
-
- vector<int> counts(W * H);
- mutex countsMutex;
-
- vector<Job> jobs = derbyshireJobs(numThreads);
- assert(jobs.size() == numThreads);
-
- vector<thread> threads(jobs.size());
- for (int i = 0; i < (int)jobs.size(); i++) {
- threads[i] = thread([W, H, &counts, &countsMutex, job = jobs[i], bottomLeft, topRight]() {
- auto calcIndex = [](double value, double left, double right, int steps) -> int {
- return (value - left) / (right - left) * (steps - 1) + 0.5;
- };
- auto calcPos = [W, H, bottomLeft, topRight, &calcIndex](Com z) -> pair<int, int> {
- return make_pair(
- calcIndex(z.real(), bottomLeft.real(), topRight.real(), W),
- calcIndex(z.imag(), bottomLeft.imag(), topRight.imag(), H)
- );
- };
-
- vector<int> localCounts(W * H);
-
- Poly poly = job.init;
- for (int i = 0; i < job.numItems; i++) {
- for (Com z : aberth(poly)) {
- int x, y;
- tie(x, y) = calcPos(z);
- if (0 <= x && x < W && 0 <= y && y < H) {
- localCounts[W * y + x]++;
- }
- }
- nextDerbyshire(poly);
- }
-
- lock_guard<mutex> guard(countsMutex);
- for (int i = 0; i < W * H; i++) counts[i] += localCounts[i];
- });
- }
-
- for (thread &th : threads) th.join();
-
- return counts;
-}
-
static void writeCounts(int W, int H, const vector<int> &counts, const char *fname) {
ofstream f(fname);
f << W << ' ' << H << '\n';
@@ -289,12 +77,12 @@ int main(int argc, char **argv) {
const Com bottomLeft = Com(-1.5, -1.5);
const Com topRight = Com(1.5, 1.5);
- // counts = computeCounts(W, H, bottomLeft, topRight);
+ // counts = computeHost(W, H, bottomLeft, topRight);
Kernel().run_chunked(counts, W, H, bottomLeft, topRight, 42, 1 << 14);
// Kernel().run_all(counts, W, H, bottomLeft, topRight, 42);
writeCounts(W, H, counts, "out.txt");
- } else if (argc == 2) {
+ } else if (argc == 2 && strcmp(argv[1], "-h") != 0) {
tie(W, H, counts) = readCounts(argv[1]);
} else {
cerr << "Usage: " << argv[0] << " -- compute and draw" << endl;