aboutsummaryrefslogtreecommitdiff
path: root/aberth/main.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'aberth/main.cpp')
-rw-r--r--aberth/main.cpp482
1 files changed, 482 insertions, 0 deletions
diff --git a/aberth/main.cpp b/aberth/main.cpp
new file mode 100644
index 0000000..3b3c11e
--- /dev/null
+++ b/aberth/main.cpp
@@ -0,0 +1,482 @@
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <vector>
+#include <array>
+#include <string>
+#include <complex>
+#include <random>
+#include <utility>
+#include <tuple>
+#include <algorithm>
+#include <thread>
+#include <mutex>
+#include <type_traits>
+#include <cstdint>
+#include <cassert>
+#include "../lodepng.h"
+
+extern "C" {
+#include "aberth_kernel.h"
+}
+
+using namespace std;
+
+
+constexpr const int N = 18;
+
+using Com = complex<double>;
+
+using Poly = array<int, N + 1>;
+
+using AApprox = array<Com, N>;
+
+
+template <typename T>
+constexpr static T clearLowestBit(T value) {
+ return value & (value - 1);
+}
+
+template <typename T>
+constexpr static bool ispow2(T value) {
+ return clearLowestBit(value) == 0;
+}
+
+template <typename T>
+constexpr static T ceil2(T value) {
+ T value2 = clearLowestBit(value);
+ if (value2 == 0) return value;
+
+ while (true) {
+ value = value2;
+ value2 = clearLowestBit(value);
+ if (value2 == 0) return value << 1;
+ }
+}
+
+__attribute__((unused))
+static ostream& operator<<(ostream &os, const Poly &p) {
+ static const char *supers[10] = {
+ "⁰", "¹", "²", "³", "⁴", "⁵", "⁶", "⁷", "⁸", "⁹"
+ };
+ os << p[0];
+ for (int i = 1; i < (int)p.size(); i++) {
+ if (p[i] < 0) os << " - " << -p[i];
+ else if (p[i] > 0) os << " + " << p[i];
+ else continue;
+
+ os << "x";
+
+ if (i == 1) continue;
+
+ ostringstream ss;
+ ss << i;
+ string s = ss.str();
+ for (char c : s) os << supers[c - '0'];
+ }
+ return os;
+}
+
+template <typename T>
+static T eval(const Poly &p, int nterms, T pt) {
+ T value = p[nterms - 1];
+ for (int i = nterms - 2; i >= 0; i--) {
+ value = pt * value + (double)p[i];
+ }
+ return value;
+}
+
+template <typename T>
+static T eval(const Poly &p, T pt) {
+ return eval(p, p.size(), pt);
+}
+
+static Poly derivative(const Poly &p) {
+ Poly res;
+ for (int i = res.size() - 2; i >= 0; i--) {
+ res[i] = (i+1) * p[i+1];
+ }
+ return res;
+}
+
+static double maxRootNorm(const Poly &poly) {
+ // Cauchy's bound: https://en.wikipedia.org/wiki/Geometrical_properties_of_polynomial_roots#Lagrange's_and_Cauchy's_bounds
+
+ double value = 0;
+ double last = (double)poly.back();
+ for (int i = 0; i < (int)poly.size() - 1; i++) {
+ value = max(value, abs(poly[i] / last));
+ }
+ return 1 + value;
+}
+
+static thread_local minstd_rand randgenerator = minstd_rand(random_device()());
+
+struct AberthState {
+ const Poly &poly;
+ Poly deriv;
+ Poly boundPoly;
+ AApprox approx;
+ double radius;
+
+ void regenerate() {
+ auto genCoord = [this]() {
+ return uniform_real_distribution<double>(-radius, radius)(randgenerator);
+ };
+ for (int i = 0; i < N; i++) {
+ approx[i] = Com(genCoord(), genCoord());
+ }
+ }
+
+ // boundPoly is 's' in the stop condition formulated at p.189-190 of
+ // https://link.springer.com/article/10.1007%2FBF02207694
+
+ AberthState(const Poly &poly)
+ : poly(poly), deriv(derivative(poly)), radius(maxRootNorm(poly)) {
+
+ regenerate();
+ for (int i = 0; i <= N; i++) {
+ boundPoly[i] = abs(poly[i]) * (4 * i + 1);
+ }
+ }
+
+ // Lagrange-style step where the new elements are computed in parallel from the previous values
+ bool step() {
+ array<Com, N * N> pairs;
+ for (int i = 0; i < N - 1; i++) {
+ for (int j = i + 1; j < N; j++) {
+ pairs[N * i + j] = 1.0 / (approx[i] - approx[j]);
+ }
+ }
+
+ bool allConverged = true;
+
+ AApprox newapprox;
+ AApprox offsets;
+ for (int i = 0; i < N; i++) {
+ Com pval = eval(poly, approx[i]);
+ Com derivval = eval(deriv, poly.size() - 1, approx[i]);
+ Com quo = pval / derivval;
+ Com sum = 0;
+ for (int j = 0; j < i; j++) sum -= pairs[N * j + i];
+ for (int j = i + 1; j < N; j++) sum += pairs[N * i + j];
+ offsets[i] = quo / (1.0 - quo * sum);
+
+ // approx[i] -= offsets[i];
+ newapprox[i] = approx[i] - offsets[i];
+
+ double sval = eval(boundPoly, abs(newapprox[i]));
+ if (abs(pval) > 1e-5 * sval) allConverged = false;
+ }
+
+ approx = newapprox;
+
+ return allConverged;
+ }
+
+ void iterate() {
+ int tries = 1, stepIdx = 1;
+ while (!step()) {
+ stepIdx++;
+
+ if (stepIdx > tries * 100) {
+ regenerate();
+ stepIdx = 0;
+ tries++;
+ }
+ }
+ }
+};
+
+static AApprox aberth(const Poly &poly) {
+ AberthState state(poly);
+ state.iterate();
+ return state.approx;
+}
+
+// Set the constant coefficient to 1; nextDerbyshire will never change it
+static Poly initDerbyshire() {
+ Poly poly;
+ poly[0] = 1;
+ fill(poly.begin() + 1, poly.end(), -1);
+ return poly;
+}
+
+// Returns whether we just looped around
+static bool nextDerbyshire(Poly &poly) {
+ for (int i = 1; i < (int)poly.size(); i++) {
+ if (poly[i] == -1) {
+ poly[i] = 1;
+ return false;
+ }
+ poly[i] = -1;
+ }
+ return true;
+}
+
+static Poly derbyshireAtIndex(int index) {
+ Poly poly;
+ poly[0] = 1;
+ for (int i = 1; i <= N; i++) {
+ poly[i] = index & 1 ? 1 : -1;
+ index >>= 1;
+ }
+ assert(index == 0);
+ return poly;
+}
+
+struct Job {
+ Poly init;
+ int numItems;
+};
+
+static vector<Job> derbyshireJobs(int targetJobs) {
+ int njobs = min(1 << N, ceil2(targetJobs));
+ int jobsize = (1 << N) / njobs;
+
+ vector<Job> jobs(njobs);
+ for (int i = 0; i < njobs; i++) {
+ jobs[i].init = derbyshireAtIndex(i * jobsize);
+ jobs[i].numItems = jobsize;
+ }
+
+ return jobs;
+}
+
+static vector<int> computeCounts(int W, int H, Com bottomLeft, Com topRight) {
+ constexpr const int numThreads = 4;
+ static_assert(ispow2(numThreads));
+
+ vector<int> counts(W * H);
+ mutex countsMutex;
+
+ vector<Job> jobs = derbyshireJobs(numThreads);
+ assert(jobs.size() == numThreads);
+
+ vector<thread> threads(jobs.size());
+ for (int i = 0; i < (int)jobs.size(); i++) {
+ threads[i] = thread([W, H, &counts, &countsMutex, job = jobs[i], bottomLeft, topRight]() {
+ auto calcIndex = [](double value, double left, double right, int steps) -> int {
+ return (value - left) / (right - left) * (steps - 1) + 0.5;
+ };
+ auto calcPos = [W, H, bottomLeft, topRight, &calcIndex](Com z) -> pair<int, int> {
+ return make_pair(
+ calcIndex(z.real(), bottomLeft.real(), topRight.real(), W),
+ calcIndex(z.imag(), bottomLeft.imag(), topRight.imag(), H)
+ );
+ };
+
+ vector<int> localCounts(W * H);
+
+ Poly poly = job.init;
+ for (int i = 0; i < job.numItems; i++) {
+ for (Com z : aberth(poly)) {
+ int x, y;
+ tie(x, y) = calcPos(z);
+ if (0 <= x && x < W && 0 <= y && y < H) {
+ localCounts[W * y + x]++;
+ }
+ }
+ nextDerbyshire(poly);
+ }
+
+ lock_guard<mutex> guard(countsMutex);
+ for (int i = 0; i < W * H; i++) counts[i] += localCounts[i];
+ });
+ }
+
+ for (thread &th : threads) th.join();
+
+ return counts;
+}
+
+static void writeCounts(int W, int H, const vector<int> &counts, const char *fname) {
+ ofstream f(fname);
+ f << W << ' ' << H << '\n';
+ for (int y = 0; y < H; y++) {
+ for (int x = 0; x < W; x++) {
+ if (x != 0) f << ' ';
+ f << counts[W * y + x];
+ }
+ f << '\n';
+ }
+}
+
+static tuple<int, int, vector<int>> readCounts(const char *fname) {
+ ifstream f(fname);
+ int W, H;
+ f >> W >> H;
+ vector<int> counts(W * H);
+ for (int &v : counts) f >> v;
+ return make_tuple(W, H, counts);
+}
+
+static int rankCounts(vector<int> &counts) {
+ int maxcount = 0;
+ for (int i = 0; i < (int)counts.size(); i++) {
+ maxcount = max(maxcount, counts[i]);
+ }
+
+ vector<int> cumul(maxcount + 1, 0);
+ for (int v : counts) cumul[v]++;
+ cumul[0] = 0;
+ for (int i = 1; i < (int)cumul.size(); i++) cumul[i] += cumul[i-1];
+ // assert(cumul[maxcount + 1] == (int)counts.size());
+
+ for (int &v : counts) v = cumul[v];
+
+ return cumul[maxcount];
+}
+
+static vector<uint8_t> drawImage(int W, int H, const vector<int> &counts, int maxcount) {
+ vector<uint8_t> image(3 * W * H);
+
+ for (int y = 0; y < H; y++) {
+ for (int x = 0; x < W; x++) {
+ double value = (double)counts[W * y + x] / maxcount * 255;
+ image[3 * (W * y + x) + 0] = value;
+ image[3 * (W * y + x) + 1] = value;
+ image[3 * (W * y + x) + 2] = value;
+ }
+ }
+
+ return image;
+}
+
+class Kernel {
+ futhark_context *ctx;
+ int32_t N;
+
+ void check_ret(int ret) {
+ if (ret != 0) {
+ char *str = futhark_context_get_error(ctx);
+ cerr << str << endl;
+ free(str);
+ exit(1);
+ }
+ };
+
+public:
+ static_assert(is_same<int32_t, int>::value);
+
+ Kernel() {
+ futhark_context_config *config = futhark_context_config_new();
+ // futhark_context_config_select_device_interactively(config);
+ // futhark_context_config_set_debugging(config, 1);
+
+ ctx = futhark_context_new(config);
+
+ futhark_context_config_free(config);
+
+ check_ret(futhark_entry_get_N(ctx, &N));
+ // The 31 check is to not exceed 32-bit signed integer bounds
+ assert(N >= 1 && N < 31);
+ }
+
+ ~Kernel() {
+ futhark_context_free(ctx);
+ }
+
+ void run_job(
+ vector<int32_t> &dest,
+ int32_t width, int32_t height,
+ Com bottomLeft, Com topRight,
+ int32_t seed,
+ int32_t start_index, int32_t poly_count) {
+
+ futhark_i32_1d *dest_arr;
+
+ check_ret(futhark_entry_main_job(
+ ctx, &dest_arr,
+ start_index, poly_count,
+ width, height,
+ bottomLeft.real(), topRight.imag(),
+ topRight.real(), bottomLeft.imag(),
+ seed));
+
+ check_ret(futhark_context_sync(ctx));
+
+ int64_t shape = futhark_shape_i32_1d(ctx, dest_arr)[0];
+ assert(shape == width * height);
+
+ dest.resize(width * height);
+ check_ret(futhark_values_i32_1d(ctx, dest_arr, dest.data()));
+ check_ret(futhark_free_i32_1d(ctx, dest_arr));
+ }
+
+ void run_all(
+ vector<int32_t> &dest,
+ int32_t width, int32_t height,
+ Com bottomLeft, Com topRight,
+ int32_t seed) {
+
+ run_job(dest, width, height, bottomLeft, topRight, seed, 0, 1 << N);
+ }
+
+ void run_chunked(
+ vector<int32_t> &dest,
+ int32_t width, int32_t height,
+ Com bottomLeft, Com topRight,
+ int32_t seed,
+ int32_t chunk_size) {
+
+ dest.clear();
+ dest.resize(width * height);
+
+ int32_t start_index = 0;
+ int32_t total = 1 << N;
+
+ int32_t njobs = (total + chunk_size - 1) / chunk_size;
+ cerr << "Running " << njobs << " jobs of size " << chunk_size << endl;
+ cerr << string(njobs, '.') << '\r';
+
+ while (start_index < total) {
+ int32_t num_polys = min(chunk_size, total - start_index);
+
+ vector<int32_t> output;
+ run_job(
+ output,
+ width, height, bottomLeft, topRight, seed,
+ start_index, num_polys);
+
+ for (int i = 0; i < width * height; i++) {
+ dest[i] += output[i];
+ }
+
+ start_index += num_polys;
+
+ cerr << '|';
+ }
+
+ cerr << endl;
+ }
+};
+
+int main(int argc, char **argv) {
+ int W, H;
+ vector<int> counts;
+
+ if (argc <= 1) {
+ W = H = 900;
+ const Com bottomLeft = Com(-1.5, -1.5);
+ const Com topRight = Com(1.5, 1.5);
+
+ counts = computeCounts(W, H, bottomLeft, topRight);
+ // Kernel().run_chunked(counts, W, H, bottomLeft, topRight, 42, 1 << 14);
+ // Kernel().run_all(counts, W, H, bottomLeft, topRight, 42);
+
+ writeCounts(W, H, counts, "out.txt");
+ } else if (argc == 2) {
+ tie(W, H, counts) = readCounts(argv[1]);
+ } else {
+ cerr << "Usage: " << argv[0] << " -- compute and draw" << endl;
+ cerr << "Usage: " << argv[0] << " <out.txt> -- draw already-computed data" << endl;
+ return 1;
+ }
+
+ int maxcount = rankCounts(counts);
+
+ vector<uint8_t> image = drawImage(W, H, counts, maxcount);
+
+ assert(lodepng_encode24_file("out.png", image.data(), W, H) == 0);
+}