diff options
Diffstat (limited to 'aberth/main.cpp')
-rw-r--r-- | aberth/main.cpp | 482 |
1 files changed, 482 insertions, 0 deletions
diff --git a/aberth/main.cpp b/aberth/main.cpp new file mode 100644 index 0000000..3b3c11e --- /dev/null +++ b/aberth/main.cpp @@ -0,0 +1,482 @@ +#include <iostream> +#include <fstream> +#include <sstream> +#include <vector> +#include <array> +#include <string> +#include <complex> +#include <random> +#include <utility> +#include <tuple> +#include <algorithm> +#include <thread> +#include <mutex> +#include <type_traits> +#include <cstdint> +#include <cassert> +#include "../lodepng.h" + +extern "C" { +#include "aberth_kernel.h" +} + +using namespace std; + + +constexpr const int N = 18; + +using Com = complex<double>; + +using Poly = array<int, N + 1>; + +using AApprox = array<Com, N>; + + +template <typename T> +constexpr static T clearLowestBit(T value) { + return value & (value - 1); +} + +template <typename T> +constexpr static bool ispow2(T value) { + return clearLowestBit(value) == 0; +} + +template <typename T> +constexpr static T ceil2(T value) { + T value2 = clearLowestBit(value); + if (value2 == 0) return value; + + while (true) { + value = value2; + value2 = clearLowestBit(value); + if (value2 == 0) return value << 1; + } +} + +__attribute__((unused)) +static ostream& operator<<(ostream &os, const Poly &p) { + static const char *supers[10] = { + "⁰", "¹", "²", "³", "⁴", "⁵", "⁶", "⁷", "⁸", "⁹" + }; + os << p[0]; + for (int i = 1; i < (int)p.size(); i++) { + if (p[i] < 0) os << " - " << -p[i]; + else if (p[i] > 0) os << " + " << p[i]; + else continue; + + os << "x"; + + if (i == 1) continue; + + ostringstream ss; + ss << i; + string s = ss.str(); + for (char c : s) os << supers[c - '0']; + } + return os; +} + +template <typename T> +static T eval(const Poly &p, int nterms, T pt) { + T value = p[nterms - 1]; + for (int i = nterms - 2; i >= 0; i--) { + value = pt * value + (double)p[i]; + } + return value; +} + +template <typename T> +static T eval(const Poly &p, T pt) { + return eval(p, p.size(), pt); +} + +static Poly derivative(const Poly &p) { + Poly res; + for (int i = res.size() - 2; i >= 0; i--) { + res[i] = (i+1) * p[i+1]; + } + return res; +} + +static double maxRootNorm(const Poly &poly) { + // Cauchy's bound: https://en.wikipedia.org/wiki/Geometrical_properties_of_polynomial_roots#Lagrange's_and_Cauchy's_bounds + + double value = 0; + double last = (double)poly.back(); + for (int i = 0; i < (int)poly.size() - 1; i++) { + value = max(value, abs(poly[i] / last)); + } + return 1 + value; +} + +static thread_local minstd_rand randgenerator = minstd_rand(random_device()()); + +struct AberthState { + const Poly &poly; + Poly deriv; + Poly boundPoly; + AApprox approx; + double radius; + + void regenerate() { + auto genCoord = [this]() { + return uniform_real_distribution<double>(-radius, radius)(randgenerator); + }; + for (int i = 0; i < N; i++) { + approx[i] = Com(genCoord(), genCoord()); + } + } + + // boundPoly is 's' in the stop condition formulated at p.189-190 of + // https://link.springer.com/article/10.1007%2FBF02207694 + + AberthState(const Poly &poly) + : poly(poly), deriv(derivative(poly)), radius(maxRootNorm(poly)) { + + regenerate(); + for (int i = 0; i <= N; i++) { + boundPoly[i] = abs(poly[i]) * (4 * i + 1); + } + } + + // Lagrange-style step where the new elements are computed in parallel from the previous values + bool step() { + array<Com, N * N> pairs; + for (int i = 0; i < N - 1; i++) { + for (int j = i + 1; j < N; j++) { + pairs[N * i + j] = 1.0 / (approx[i] - approx[j]); + } + } + + bool allConverged = true; + + AApprox newapprox; + AApprox offsets; + for (int i = 0; i < N; i++) { + Com pval = eval(poly, approx[i]); + Com derivval = eval(deriv, poly.size() - 1, approx[i]); + Com quo = pval / derivval; + Com sum = 0; + for (int j = 0; j < i; j++) sum -= pairs[N * j + i]; + for (int j = i + 1; j < N; j++) sum += pairs[N * i + j]; + offsets[i] = quo / (1.0 - quo * sum); + + // approx[i] -= offsets[i]; + newapprox[i] = approx[i] - offsets[i]; + + double sval = eval(boundPoly, abs(newapprox[i])); + if (abs(pval) > 1e-5 * sval) allConverged = false; + } + + approx = newapprox; + + return allConverged; + } + + void iterate() { + int tries = 1, stepIdx = 1; + while (!step()) { + stepIdx++; + + if (stepIdx > tries * 100) { + regenerate(); + stepIdx = 0; + tries++; + } + } + } +}; + +static AApprox aberth(const Poly &poly) { + AberthState state(poly); + state.iterate(); + return state.approx; +} + +// Set the constant coefficient to 1; nextDerbyshire will never change it +static Poly initDerbyshire() { + Poly poly; + poly[0] = 1; + fill(poly.begin() + 1, poly.end(), -1); + return poly; +} + +// Returns whether we just looped around +static bool nextDerbyshire(Poly &poly) { + for (int i = 1; i < (int)poly.size(); i++) { + if (poly[i] == -1) { + poly[i] = 1; + return false; + } + poly[i] = -1; + } + return true; +} + +static Poly derbyshireAtIndex(int index) { + Poly poly; + poly[0] = 1; + for (int i = 1; i <= N; i++) { + poly[i] = index & 1 ? 1 : -1; + index >>= 1; + } + assert(index == 0); + return poly; +} + +struct Job { + Poly init; + int numItems; +}; + +static vector<Job> derbyshireJobs(int targetJobs) { + int njobs = min(1 << N, ceil2(targetJobs)); + int jobsize = (1 << N) / njobs; + + vector<Job> jobs(njobs); + for (int i = 0; i < njobs; i++) { + jobs[i].init = derbyshireAtIndex(i * jobsize); + jobs[i].numItems = jobsize; + } + + return jobs; +} + +static vector<int> computeCounts(int W, int H, Com bottomLeft, Com topRight) { + constexpr const int numThreads = 4; + static_assert(ispow2(numThreads)); + + vector<int> counts(W * H); + mutex countsMutex; + + vector<Job> jobs = derbyshireJobs(numThreads); + assert(jobs.size() == numThreads); + + vector<thread> threads(jobs.size()); + for (int i = 0; i < (int)jobs.size(); i++) { + threads[i] = thread([W, H, &counts, &countsMutex, job = jobs[i], bottomLeft, topRight]() { + auto calcIndex = [](double value, double left, double right, int steps) -> int { + return (value - left) / (right - left) * (steps - 1) + 0.5; + }; + auto calcPos = [W, H, bottomLeft, topRight, &calcIndex](Com z) -> pair<int, int> { + return make_pair( + calcIndex(z.real(), bottomLeft.real(), topRight.real(), W), + calcIndex(z.imag(), bottomLeft.imag(), topRight.imag(), H) + ); + }; + + vector<int> localCounts(W * H); + + Poly poly = job.init; + for (int i = 0; i < job.numItems; i++) { + for (Com z : aberth(poly)) { + int x, y; + tie(x, y) = calcPos(z); + if (0 <= x && x < W && 0 <= y && y < H) { + localCounts[W * y + x]++; + } + } + nextDerbyshire(poly); + } + + lock_guard<mutex> guard(countsMutex); + for (int i = 0; i < W * H; i++) counts[i] += localCounts[i]; + }); + } + + for (thread &th : threads) th.join(); + + return counts; +} + +static void writeCounts(int W, int H, const vector<int> &counts, const char *fname) { + ofstream f(fname); + f << W << ' ' << H << '\n'; + for (int y = 0; y < H; y++) { + for (int x = 0; x < W; x++) { + if (x != 0) f << ' '; + f << counts[W * y + x]; + } + f << '\n'; + } +} + +static tuple<int, int, vector<int>> readCounts(const char *fname) { + ifstream f(fname); + int W, H; + f >> W >> H; + vector<int> counts(W * H); + for (int &v : counts) f >> v; + return make_tuple(W, H, counts); +} + +static int rankCounts(vector<int> &counts) { + int maxcount = 0; + for (int i = 0; i < (int)counts.size(); i++) { + maxcount = max(maxcount, counts[i]); + } + + vector<int> cumul(maxcount + 1, 0); + for (int v : counts) cumul[v]++; + cumul[0] = 0; + for (int i = 1; i < (int)cumul.size(); i++) cumul[i] += cumul[i-1]; + // assert(cumul[maxcount + 1] == (int)counts.size()); + + for (int &v : counts) v = cumul[v]; + + return cumul[maxcount]; +} + +static vector<uint8_t> drawImage(int W, int H, const vector<int> &counts, int maxcount) { + vector<uint8_t> image(3 * W * H); + + for (int y = 0; y < H; y++) { + for (int x = 0; x < W; x++) { + double value = (double)counts[W * y + x] / maxcount * 255; + image[3 * (W * y + x) + 0] = value; + image[3 * (W * y + x) + 1] = value; + image[3 * (W * y + x) + 2] = value; + } + } + + return image; +} + +class Kernel { + futhark_context *ctx; + int32_t N; + + void check_ret(int ret) { + if (ret != 0) { + char *str = futhark_context_get_error(ctx); + cerr << str << endl; + free(str); + exit(1); + } + }; + +public: + static_assert(is_same<int32_t, int>::value); + + Kernel() { + futhark_context_config *config = futhark_context_config_new(); + // futhark_context_config_select_device_interactively(config); + // futhark_context_config_set_debugging(config, 1); + + ctx = futhark_context_new(config); + + futhark_context_config_free(config); + + check_ret(futhark_entry_get_N(ctx, &N)); + // The 31 check is to not exceed 32-bit signed integer bounds + assert(N >= 1 && N < 31); + } + + ~Kernel() { + futhark_context_free(ctx); + } + + void run_job( + vector<int32_t> &dest, + int32_t width, int32_t height, + Com bottomLeft, Com topRight, + int32_t seed, + int32_t start_index, int32_t poly_count) { + + futhark_i32_1d *dest_arr; + + check_ret(futhark_entry_main_job( + ctx, &dest_arr, + start_index, poly_count, + width, height, + bottomLeft.real(), topRight.imag(), + topRight.real(), bottomLeft.imag(), + seed)); + + check_ret(futhark_context_sync(ctx)); + + int64_t shape = futhark_shape_i32_1d(ctx, dest_arr)[0]; + assert(shape == width * height); + + dest.resize(width * height); + check_ret(futhark_values_i32_1d(ctx, dest_arr, dest.data())); + check_ret(futhark_free_i32_1d(ctx, dest_arr)); + } + + void run_all( + vector<int32_t> &dest, + int32_t width, int32_t height, + Com bottomLeft, Com topRight, + int32_t seed) { + + run_job(dest, width, height, bottomLeft, topRight, seed, 0, 1 << N); + } + + void run_chunked( + vector<int32_t> &dest, + int32_t width, int32_t height, + Com bottomLeft, Com topRight, + int32_t seed, + int32_t chunk_size) { + + dest.clear(); + dest.resize(width * height); + + int32_t start_index = 0; + int32_t total = 1 << N; + + int32_t njobs = (total + chunk_size - 1) / chunk_size; + cerr << "Running " << njobs << " jobs of size " << chunk_size << endl; + cerr << string(njobs, '.') << '\r'; + + while (start_index < total) { + int32_t num_polys = min(chunk_size, total - start_index); + + vector<int32_t> output; + run_job( + output, + width, height, bottomLeft, topRight, seed, + start_index, num_polys); + + for (int i = 0; i < width * height; i++) { + dest[i] += output[i]; + } + + start_index += num_polys; + + cerr << '|'; + } + + cerr << endl; + } +}; + +int main(int argc, char **argv) { + int W, H; + vector<int> counts; + + if (argc <= 1) { + W = H = 900; + const Com bottomLeft = Com(-1.5, -1.5); + const Com topRight = Com(1.5, 1.5); + + counts = computeCounts(W, H, bottomLeft, topRight); + // Kernel().run_chunked(counts, W, H, bottomLeft, topRight, 42, 1 << 14); + // Kernel().run_all(counts, W, H, bottomLeft, topRight, 42); + + writeCounts(W, H, counts, "out.txt"); + } else if (argc == 2) { + tie(W, H, counts) = readCounts(argv[1]); + } else { + cerr << "Usage: " << argv[0] << " -- compute and draw" << endl; + cerr << "Usage: " << argv[0] << " <out.txt> -- draw already-computed data" << endl; + return 1; + } + + int maxcount = rankCounts(counts); + + vector<uint8_t> image = drawImage(W, H, counts, maxcount); + + assert(lodepng_encode24_file("out.png", image.data(), W, H) == 0); +} |