aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--aberth/.gitignore1
-rw-r--r--aberth/Makefile20
-rw-r--r--aberth/aberth.cpp357
3 files changed, 378 insertions, 0 deletions
diff --git a/aberth/.gitignore b/aberth/.gitignore
new file mode 100644
index 0000000..60d8b01
--- /dev/null
+++ b/aberth/.gitignore
@@ -0,0 +1 @@
+aberth
diff --git a/aberth/Makefile b/aberth/Makefile
new file mode 100644
index 0000000..c5a67cf
--- /dev/null
+++ b/aberth/Makefile
@@ -0,0 +1,20 @@
+CXX = g++
+CXXFLAGS = -Wall -Wextra -std=c++17 -O3 -fwrapv -ffast-math -march=native -mtune=native
+LDFLAGS = -pthread
+
+.PHONY: all clean remake
+
+all: aberth
+
+clean:
+ rm -f aberth *.o
+
+remake: clean
+ $(MAKE) all
+
+
+aberth: $(patsubst %.cpp,%.o,$(wildcard *.cpp)) ../lodepng.o
+ $(CXX) -o $@ $^ $(LDFLAGS)
+
+%.o: %.cpp $(wildcard *.h)
+ $(CXX) $(CXXFLAGS) -c -o $@ $<
diff --git a/aberth/aberth.cpp b/aberth/aberth.cpp
new file mode 100644
index 0000000..2adf207
--- /dev/null
+++ b/aberth/aberth.cpp
@@ -0,0 +1,357 @@
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <vector>
+#include <array>
+#include <string>
+#include <complex>
+#include <utility>
+#include <tuple>
+#include <algorithm>
+#include <numeric>
+#include <thread>
+#include <mutex>
+#include <cstdlib>
+#include <cstdint>
+#include <cassert>
+#include "../lodepng.h"
+
+using namespace std;
+
+
+constexpr const int N = 18;
+
+using Com = complex<double>;
+
+using Poly = array<int, N + 1>;
+
+using AApprox = array<Com, N>;
+
+
+template <typename T>
+constexpr static T clearLowestBit(T value) {
+ return value & (value - 1);
+}
+
+template <typename T>
+constexpr static bool ispow2(T value) {
+ return clearLowestBit(value) == 0;
+}
+
+template <typename T>
+constexpr static T ceil2(T value) {
+ T value2 = clearLowestBit(value);
+ if (value2 == 0) return value;
+
+ while (true) {
+ value = value2;
+ value2 = clearLowestBit(value);
+ if (value2 == 0) return value << 1;
+ }
+}
+
+__attribute__((unused))
+static ostream& operator<<(ostream &os, const Poly &p) {
+ static const char *supers[10] = {
+ "⁰", "¹", "²", "³", "⁴", "⁵", "⁶", "⁷", "⁸", "⁹"
+ };
+ os << p[0];
+ for (int i = 1; i < (int)p.size(); i++) {
+ if (p[i] < 0) os << " - " << -p[i];
+ else if (p[i] > 0) os << " + " << p[i];
+ else continue;
+
+ os << "x";
+
+ if (i == 1) continue;
+
+ ostringstream ss;
+ ss << i;
+ string s = ss.str();
+ for (char c : s) os << supers[c - '0'];
+ }
+ return os;
+}
+
+template <typename T>
+static T eval(const Poly &p, int nterms, T pt) {
+ T value = p[nterms - 1];
+ for (int i = nterms - 2; i >= 0; i--) {
+ value = pt * value + (double)p[i];
+ }
+ return value;
+}
+
+template <typename T>
+static T eval(const Poly &p, T pt) {
+ return eval(p, p.size(), pt);
+}
+
+static Poly derivative(const Poly &p) {
+ Poly res;
+ for (int i = res.size() - 2; i >= 0; i--) {
+ res[i] = (i+1) * p[i+1];
+ }
+ return res;
+}
+
+static double maxRootNorm(const Poly &poly) {
+ // Cauchy's bound: https://en.wikipedia.org/wiki/Geometrical_properties_of_polynomial_roots#Lagrange's_and_Cauchy's_bounds
+
+ double value = 0;
+ double last = (double)poly.back();
+ for (int i = 0; i < (int)poly.size() - 1; i++) {
+ value = max(value, abs(poly[i] / last));
+ }
+ return 1 + value;
+}
+
+struct AberthState {
+ const Poly &poly;
+ Poly deriv;
+ Poly boundPoly;
+ AApprox approx;
+ double radius;
+
+ void regenerate() {
+ auto genCoord = [this]() { return (double)random() / INT_MAX * 2 * radius - radius; };
+ for (int i = 0; i < N; i++) {
+ approx[i] = Com(genCoord(), genCoord());
+ }
+ }
+
+ // boundPoly is 's' in the stop condition formulated at p.189-190 of
+ // https://link.springer.com/article/10.1007%2FBF02207694
+
+ AberthState(const Poly &poly)
+ : poly(poly), deriv(derivative(poly)), radius(maxRootNorm(poly)) {
+
+ regenerate();
+ for (int i = 0; i <= N; i++) {
+ boundPoly[i] = abs(poly[i]) * (4 * i + 1);
+ }
+ }
+
+ // Gauss-Seidel-style step where the updated values are already used in the current iteration
+ bool step() {
+ array<Com, N * N> pairs;
+ for (int i = 0; i < N - 1; i++) {
+ for (int j = i + 1; j < N; j++) {
+ pairs[N * i + j] = 1.0 / (approx[i] - approx[j]);
+ }
+ }
+
+ bool allConverged = true;
+
+ AApprox offsets;
+ for (int i = 0; i < N; i++) {
+ Com pval = eval(poly, approx[i]);
+ Com derivval = eval(deriv, poly.size() - 1, approx[i]);
+ Com quo = pval / derivval;
+ Com sum = 0;
+ for (int j = 0; j < i; j++) sum += pairs[N * j + i];
+ for (int j = i + 1; j < N; j++) sum += pairs[N * i + j];
+ offsets[i] = quo / (1.0 - quo * sum);
+
+ approx[i] -= offsets[i];
+
+ double sval = eval(boundPoly, abs(approx[i]));
+ if (abs(pval) > 1e-9 * sval) allConverged = false;
+ }
+
+ return allConverged;
+ }
+
+ void iterate() {
+ int tries = 1, stepIdx = 1;
+ while (!step()) {
+ stepIdx++;
+
+ if (stepIdx > tries * 100) {
+ regenerate();
+ stepIdx = 0;
+ tries++;
+ }
+ }
+ }
+};
+
+static AApprox aberth(const Poly &poly) {
+ AberthState state(poly);
+ state.iterate();
+ return state.approx;
+}
+
+// Set the constant coefficient to 1; nextDerbyshire will never change it
+static Poly initDerbyshire() {
+ Poly poly;
+ poly[0] = 1;
+ fill(poly.begin() + 1, poly.end(), -1);
+ return poly;
+}
+
+// Returns whether we just looped around
+static bool nextDerbyshire(Poly &poly) {
+ for (int i = 1; i < (int)poly.size(); i++) {
+ if (poly[i] == -1) {
+ poly[i] = 1;
+ return false;
+ }
+ poly[i] = -1;
+ }
+ return true;
+}
+
+static Poly derbyshireAtIndex(int index) {
+ Poly poly;
+ poly[0] = 1;
+ for (int i = 1; i <= N; i++) {
+ poly[i] = index & 1 ? 1 : -1;
+ index >>= 1;
+ }
+ assert(index == 0);
+ return poly;
+}
+
+struct Job {
+ Poly init;
+ int numItems;
+};
+
+static vector<Job> derbyshireJobs(int targetJobs) {
+ int njobs = min(1 << N, ceil2(targetJobs));
+ int jobsize = (1 << N) / njobs;
+
+ vector<Job> jobs(njobs);
+ for (int i = 0; i < njobs; i++) {
+ jobs[i].init = derbyshireAtIndex(i * jobsize);
+ jobs[i].numItems = jobsize;
+ }
+
+ return jobs;
+}
+
+static tuple<int, int, vector<int>> computeCounts() {
+ constexpr const int W = 900;
+ constexpr const int H = 900;
+ constexpr const int numThreads = 4;
+ static_assert(ispow2(numThreads));
+ constexpr const Com bottomLeft = Com(-1.5, -1.5);
+ constexpr const Com topRight = Com(1.5, 1.5);
+
+ vector<int> counts(W * H);
+ mutex countsMutex;
+
+ vector<Job> jobs = derbyshireJobs(numThreads);
+ assert(jobs.size() == numThreads);
+
+ vector<thread> threads(jobs.size());
+ for (int i = 0; i < (int)jobs.size(); i++) {
+ threads[i] = thread([&counts, &countsMutex, job = jobs[i], bottomLeft, topRight]() {
+ auto calcIndex = [](double value, double left, double right, int steps) -> int {
+ return (value - left) / (right - left) * (steps - 1) + 0.5;
+ };
+ auto calcPos = [bottomLeft, topRight, &calcIndex](Com z) -> pair<int, int> {
+ return make_pair(
+ calcIndex(z.real(), bottomLeft.real(), topRight.real(), W),
+ calcIndex(z.imag(), bottomLeft.imag(), topRight.imag(), H)
+ );
+ };
+
+ vector<int> localCounts(W * H);
+
+ Poly poly = job.init;
+ for (int i = 0; i < job.numItems; i++) {
+ for (Com z : aberth(poly)) {
+ int x, y;
+ tie(x, y) = calcPos(z);
+ if (0 <= x && x < W && 0 <= y && y < H) {
+ localCounts[W * y + x]++;
+ }
+ }
+ nextDerbyshire(poly);
+ }
+
+ lock_guard guard(countsMutex);
+ for (int i = 0; i < W * H; i++) counts[i] += localCounts[i];
+ });
+ }
+
+ for (thread &th : threads) th.join();
+
+ return make_tuple(W, H, counts);
+}
+
+static void writeCounts(int W, int H, const vector<int> &counts, const char *fname) {
+ ofstream f(fname);
+ f << W << ' ' << H << '\n';
+ for (int y = 0; y < H; y++) {
+ for (int x = 0; x < W; x++) {
+ if (x != 0) f << ' ';
+ f << counts[W * y + x];
+ }
+ f << '\n';
+ }
+}
+
+static tuple<int, int, vector<int>> readCounts(const char *fname) {
+ ifstream f(fname);
+ int W, H;
+ f >> W >> H;
+ vector<int> counts(W * H);
+ for (int &v : counts) f >> v;
+ return make_tuple(W, H, counts);
+}
+
+static int rankCounts(vector<int> &counts) {
+ int maxcount = reduce(counts.begin(), counts.end(), 0, [](int a, int b) -> int { return max(a, b); });
+
+ vector<int> cumul(maxcount + 1, 0);
+ for (int v : counts) cumul[v]++;
+ cumul[0] = 0;
+ for (int i = 1; i < (int)cumul.size(); i++) cumul[i] += cumul[i-1];
+ // assert(cumul[maxcount + 1] == (int)counts.size());
+
+ for (int &v : counts) v = cumul[v];
+
+ return cumul[maxcount];
+}
+
+static vector<uint8_t> drawImage(int W, int H, const vector<int> &counts, int maxcount) {
+ vector<uint8_t> image(3 * W * H);
+
+ for (int y = 0; y < H; y++) {
+ for (int x = 0; x < W; x++) {
+ double value = (double)counts[W * y + x] / maxcount * 255;
+ image[3 * (W * y + x) + 0] = value;
+ image[3 * (W * y + x) + 1] = value;
+ image[3 * (W * y + x) + 2] = value;
+ }
+ }
+
+ return image;
+}
+
+int main(int argc, char **argv) {
+ srandomdev();
+
+ int W, H;
+ vector<int> counts;
+
+ if (argc <= 1) {
+ tie(W, H, counts) = computeCounts();
+ writeCounts(W, H, counts, "out.txt");
+ } else if (argc == 2) {
+ tie(W, H, counts) = readCounts(argv[1]);
+ } else {
+ cerr << "Usage: " << argv[0] << " -- compute and draw" << endl;
+ cerr << "Usage: " << argv[0] << " <out.txt> -- draw already-computed data" << endl;
+ return 1;
+ }
+
+ int maxcount = rankCounts(counts);
+
+ vector<uint8_t> image = drawImage(W, H, counts, maxcount);
+
+ assert(lodepng_encode24_file("out.png", image.data(), W, H) == 0);
+}