diff options
-rw-r--r-- | aberth/.gitignore | 1 | ||||
-rw-r--r-- | aberth/Makefile | 20 | ||||
-rw-r--r-- | aberth/aberth.cpp | 357 |
3 files changed, 378 insertions, 0 deletions
diff --git a/aberth/.gitignore b/aberth/.gitignore new file mode 100644 index 0000000..60d8b01 --- /dev/null +++ b/aberth/.gitignore @@ -0,0 +1 @@ +aberth diff --git a/aberth/Makefile b/aberth/Makefile new file mode 100644 index 0000000..c5a67cf --- /dev/null +++ b/aberth/Makefile @@ -0,0 +1,20 @@ +CXX = g++ +CXXFLAGS = -Wall -Wextra -std=c++17 -O3 -fwrapv -ffast-math -march=native -mtune=native +LDFLAGS = -pthread + +.PHONY: all clean remake + +all: aberth + +clean: + rm -f aberth *.o + +remake: clean + $(MAKE) all + + +aberth: $(patsubst %.cpp,%.o,$(wildcard *.cpp)) ../lodepng.o + $(CXX) -o $@ $^ $(LDFLAGS) + +%.o: %.cpp $(wildcard *.h) + $(CXX) $(CXXFLAGS) -c -o $@ $< diff --git a/aberth/aberth.cpp b/aberth/aberth.cpp new file mode 100644 index 0000000..2adf207 --- /dev/null +++ b/aberth/aberth.cpp @@ -0,0 +1,357 @@ +#include <iostream> +#include <fstream> +#include <sstream> +#include <vector> +#include <array> +#include <string> +#include <complex> +#include <utility> +#include <tuple> +#include <algorithm> +#include <numeric> +#include <thread> +#include <mutex> +#include <cstdlib> +#include <cstdint> +#include <cassert> +#include "../lodepng.h" + +using namespace std; + + +constexpr const int N = 18; + +using Com = complex<double>; + +using Poly = array<int, N + 1>; + +using AApprox = array<Com, N>; + + +template <typename T> +constexpr static T clearLowestBit(T value) { + return value & (value - 1); +} + +template <typename T> +constexpr static bool ispow2(T value) { + return clearLowestBit(value) == 0; +} + +template <typename T> +constexpr static T ceil2(T value) { + T value2 = clearLowestBit(value); + if (value2 == 0) return value; + + while (true) { + value = value2; + value2 = clearLowestBit(value); + if (value2 == 0) return value << 1; + } +} + +__attribute__((unused)) +static ostream& operator<<(ostream &os, const Poly &p) { + static const char *supers[10] = { + "⁰", "¹", "²", "³", "⁴", "⁵", "⁶", "⁷", "⁸", "⁹" + }; + os << p[0]; + for (int i = 1; i < (int)p.size(); i++) { + if (p[i] < 0) os << " - " << -p[i]; + else if (p[i] > 0) os << " + " << p[i]; + else continue; + + os << "x"; + + if (i == 1) continue; + + ostringstream ss; + ss << i; + string s = ss.str(); + for (char c : s) os << supers[c - '0']; + } + return os; +} + +template <typename T> +static T eval(const Poly &p, int nterms, T pt) { + T value = p[nterms - 1]; + for (int i = nterms - 2; i >= 0; i--) { + value = pt * value + (double)p[i]; + } + return value; +} + +template <typename T> +static T eval(const Poly &p, T pt) { + return eval(p, p.size(), pt); +} + +static Poly derivative(const Poly &p) { + Poly res; + for (int i = res.size() - 2; i >= 0; i--) { + res[i] = (i+1) * p[i+1]; + } + return res; +} + +static double maxRootNorm(const Poly &poly) { + // Cauchy's bound: https://en.wikipedia.org/wiki/Geometrical_properties_of_polynomial_roots#Lagrange's_and_Cauchy's_bounds + + double value = 0; + double last = (double)poly.back(); + for (int i = 0; i < (int)poly.size() - 1; i++) { + value = max(value, abs(poly[i] / last)); + } + return 1 + value; +} + +struct AberthState { + const Poly &poly; + Poly deriv; + Poly boundPoly; + AApprox approx; + double radius; + + void regenerate() { + auto genCoord = [this]() { return (double)random() / INT_MAX * 2 * radius - radius; }; + for (int i = 0; i < N; i++) { + approx[i] = Com(genCoord(), genCoord()); + } + } + + // boundPoly is 's' in the stop condition formulated at p.189-190 of + // https://link.springer.com/article/10.1007%2FBF02207694 + + AberthState(const Poly &poly) + : poly(poly), deriv(derivative(poly)), radius(maxRootNorm(poly)) { + + regenerate(); + for (int i = 0; i <= N; i++) { + boundPoly[i] = abs(poly[i]) * (4 * i + 1); + } + } + + // Gauss-Seidel-style step where the updated values are already used in the current iteration + bool step() { + array<Com, N * N> pairs; + for (int i = 0; i < N - 1; i++) { + for (int j = i + 1; j < N; j++) { + pairs[N * i + j] = 1.0 / (approx[i] - approx[j]); + } + } + + bool allConverged = true; + + AApprox offsets; + for (int i = 0; i < N; i++) { + Com pval = eval(poly, approx[i]); + Com derivval = eval(deriv, poly.size() - 1, approx[i]); + Com quo = pval / derivval; + Com sum = 0; + for (int j = 0; j < i; j++) sum += pairs[N * j + i]; + for (int j = i + 1; j < N; j++) sum += pairs[N * i + j]; + offsets[i] = quo / (1.0 - quo * sum); + + approx[i] -= offsets[i]; + + double sval = eval(boundPoly, abs(approx[i])); + if (abs(pval) > 1e-9 * sval) allConverged = false; + } + + return allConverged; + } + + void iterate() { + int tries = 1, stepIdx = 1; + while (!step()) { + stepIdx++; + + if (stepIdx > tries * 100) { + regenerate(); + stepIdx = 0; + tries++; + } + } + } +}; + +static AApprox aberth(const Poly &poly) { + AberthState state(poly); + state.iterate(); + return state.approx; +} + +// Set the constant coefficient to 1; nextDerbyshire will never change it +static Poly initDerbyshire() { + Poly poly; + poly[0] = 1; + fill(poly.begin() + 1, poly.end(), -1); + return poly; +} + +// Returns whether we just looped around +static bool nextDerbyshire(Poly &poly) { + for (int i = 1; i < (int)poly.size(); i++) { + if (poly[i] == -1) { + poly[i] = 1; + return false; + } + poly[i] = -1; + } + return true; +} + +static Poly derbyshireAtIndex(int index) { + Poly poly; + poly[0] = 1; + for (int i = 1; i <= N; i++) { + poly[i] = index & 1 ? 1 : -1; + index >>= 1; + } + assert(index == 0); + return poly; +} + +struct Job { + Poly init; + int numItems; +}; + +static vector<Job> derbyshireJobs(int targetJobs) { + int njobs = min(1 << N, ceil2(targetJobs)); + int jobsize = (1 << N) / njobs; + + vector<Job> jobs(njobs); + for (int i = 0; i < njobs; i++) { + jobs[i].init = derbyshireAtIndex(i * jobsize); + jobs[i].numItems = jobsize; + } + + return jobs; +} + +static tuple<int, int, vector<int>> computeCounts() { + constexpr const int W = 900; + constexpr const int H = 900; + constexpr const int numThreads = 4; + static_assert(ispow2(numThreads)); + constexpr const Com bottomLeft = Com(-1.5, -1.5); + constexpr const Com topRight = Com(1.5, 1.5); + + vector<int> counts(W * H); + mutex countsMutex; + + vector<Job> jobs = derbyshireJobs(numThreads); + assert(jobs.size() == numThreads); + + vector<thread> threads(jobs.size()); + for (int i = 0; i < (int)jobs.size(); i++) { + threads[i] = thread([&counts, &countsMutex, job = jobs[i], bottomLeft, topRight]() { + auto calcIndex = [](double value, double left, double right, int steps) -> int { + return (value - left) / (right - left) * (steps - 1) + 0.5; + }; + auto calcPos = [bottomLeft, topRight, &calcIndex](Com z) -> pair<int, int> { + return make_pair( + calcIndex(z.real(), bottomLeft.real(), topRight.real(), W), + calcIndex(z.imag(), bottomLeft.imag(), topRight.imag(), H) + ); + }; + + vector<int> localCounts(W * H); + + Poly poly = job.init; + for (int i = 0; i < job.numItems; i++) { + for (Com z : aberth(poly)) { + int x, y; + tie(x, y) = calcPos(z); + if (0 <= x && x < W && 0 <= y && y < H) { + localCounts[W * y + x]++; + } + } + nextDerbyshire(poly); + } + + lock_guard guard(countsMutex); + for (int i = 0; i < W * H; i++) counts[i] += localCounts[i]; + }); + } + + for (thread &th : threads) th.join(); + + return make_tuple(W, H, counts); +} + +static void writeCounts(int W, int H, const vector<int> &counts, const char *fname) { + ofstream f(fname); + f << W << ' ' << H << '\n'; + for (int y = 0; y < H; y++) { + for (int x = 0; x < W; x++) { + if (x != 0) f << ' '; + f << counts[W * y + x]; + } + f << '\n'; + } +} + +static tuple<int, int, vector<int>> readCounts(const char *fname) { + ifstream f(fname); + int W, H; + f >> W >> H; + vector<int> counts(W * H); + for (int &v : counts) f >> v; + return make_tuple(W, H, counts); +} + +static int rankCounts(vector<int> &counts) { + int maxcount = reduce(counts.begin(), counts.end(), 0, [](int a, int b) -> int { return max(a, b); }); + + vector<int> cumul(maxcount + 1, 0); + for (int v : counts) cumul[v]++; + cumul[0] = 0; + for (int i = 1; i < (int)cumul.size(); i++) cumul[i] += cumul[i-1]; + // assert(cumul[maxcount + 1] == (int)counts.size()); + + for (int &v : counts) v = cumul[v]; + + return cumul[maxcount]; +} + +static vector<uint8_t> drawImage(int W, int H, const vector<int> &counts, int maxcount) { + vector<uint8_t> image(3 * W * H); + + for (int y = 0; y < H; y++) { + for (int x = 0; x < W; x++) { + double value = (double)counts[W * y + x] / maxcount * 255; + image[3 * (W * y + x) + 0] = value; + image[3 * (W * y + x) + 1] = value; + image[3 * (W * y + x) + 2] = value; + } + } + + return image; +} + +int main(int argc, char **argv) { + srandomdev(); + + int W, H; + vector<int> counts; + + if (argc <= 1) { + tie(W, H, counts) = computeCounts(); + writeCounts(W, H, counts, "out.txt"); + } else if (argc == 2) { + tie(W, H, counts) = readCounts(argv[1]); + } else { + cerr << "Usage: " << argv[0] << " -- compute and draw" << endl; + cerr << "Usage: " << argv[0] << " <out.txt> -- draw already-computed data" << endl; + return 1; + } + + int maxcount = rankCounts(counts); + + vector<uint8_t> image = drawImage(W, H, counts, maxcount); + + assert(lodepng_encode24_file("out.png", image.data(), W, H) == 0); +} |