diff options
Diffstat (limited to 'aberth/main.cpp')
-rw-r--r-- | aberth/main.cpp | 174 |
1 files changed, 5 insertions, 169 deletions
diff --git a/aberth/main.cpp b/aberth/main.cpp index eb3cef7..1cdea3a 100644 --- a/aberth/main.cpp +++ b/aberth/main.cpp @@ -15,68 +15,13 @@ #include <cstdint> #include <cassert> #include "../lodepng.h" - -extern "C" { -#include "aberth_kernel.h" -} +#include "defs.h" +#include "kernel.h" +#include "util.h" using namespace std; -constexpr const int N = 18; - -using Com = complex<double>; - -using Poly = array<int, N + 1>; - -using AApprox = array<Com, N>; - - -template <typename T> -constexpr static T clearLowestBit(T value) { - return value & (value - 1); -} - -template <typename T> -constexpr static bool ispow2(T value) { - return clearLowestBit(value) == 0; -} - -template <typename T> -constexpr static T ceil2(T value) { - T value2 = clearLowestBit(value); - if (value2 == 0) return value; - - while (true) { - value = value2; - value2 = clearLowestBit(value); - if (value2 == 0) return value << 1; - } -} - -__attribute__((unused)) -static ostream& operator<<(ostream &os, const Poly &p) { - static const char *supers[10] = { - "⁰", "¹", "²", "³", "⁴", "⁵", "⁶", "⁷", "⁸", "⁹" - }; - os << p[0]; - for (int i = 1; i < (int)p.size(); i++) { - if (p[i] < 0) os << " - " << -p[i]; - else if (p[i] > 0) os << " + " << p[i]; - else continue; - - os << "x"; - - if (i == 1) continue; - - ostringstream ss; - ss << i; - string s = ss.str(); - for (char c : s) os << supers[c - '0']; - } - return os; -} - template <typename T> static T eval(const Poly &p, int nterms, T pt) { T value = p[nterms - 1]; @@ -335,115 +280,6 @@ static vector<uint8_t> drawImage(int W, int H, const vector<int> &counts, int ma return image; } -class Kernel { - futhark_context *ctx; - int32_t N; - - void check_ret(int ret) { - if (ret != 0) { - char *str = futhark_context_get_error(ctx); - cerr << str << endl; - free(str); - exit(1); - } - }; - -public: - static_assert(is_same<int32_t, int>::value); - - Kernel() { - futhark_context_config *config = futhark_context_config_new(); - // futhark_context_config_select_device_interactively(config); - // futhark_context_config_set_debugging(config, 1); - - ctx = futhark_context_new(config); - - futhark_context_config_free(config); - - check_ret(futhark_entry_get_N(ctx, &N)); - // The 31 check is to not exceed 32-bit signed integer bounds - assert(N >= 1 && N < 31); - } - - ~Kernel() { - futhark_context_free(ctx); - } - - void run_job( - vector<int32_t> &dest, - int32_t width, int32_t height, - Com bottomLeft, Com topRight, - int32_t seed, - int32_t start_index, int32_t poly_count) { - - futhark_i32_1d *dest_arr; - - check_ret(futhark_entry_main_job( - ctx, &dest_arr, - start_index, poly_count, - width, height, - bottomLeft.real(), topRight.imag(), - topRight.real(), bottomLeft.imag(), - seed)); - - check_ret(futhark_context_sync(ctx)); - - int64_t shape = futhark_shape_i32_1d(ctx, dest_arr)[0]; - assert(shape == width * height); - - dest.resize(width * height); - check_ret(futhark_values_i32_1d(ctx, dest_arr, dest.data())); - check_ret(futhark_free_i32_1d(ctx, dest_arr)); - } - - void run_all( - vector<int32_t> &dest, - int32_t width, int32_t height, - Com bottomLeft, Com topRight, - int32_t seed) { - - run_job(dest, width, height, bottomLeft, topRight, seed, 0, 1 << N); - } - - void run_chunked( - vector<int32_t> &dest, - int32_t width, int32_t height, - Com bottomLeft, Com topRight, - int32_t seed, - int32_t chunk_size) { - - dest.clear(); - dest.resize(width * height); - - int32_t start_index = 0; - int32_t total = 1 << N; - - int32_t njobs = (total + chunk_size - 1) / chunk_size; - cerr << "Running " << njobs << " jobs of size " << chunk_size << endl; - cerr << string(njobs, '.') << '\r'; - - while (start_index < total) { - int32_t num_polys = min(chunk_size, total - start_index); - - vector<int32_t> output; - run_job( - output, - width, height, bottomLeft, topRight, seed, - start_index, num_polys); - - for (int i = 0; i < width * height; i++) { - dest[i] += output[i]; - } - - start_index += num_polys; - - cerr << '|'; - } - - cerr << endl; - } -}; - int main(int argc, char **argv) { int W, H; vector<int> counts; @@ -453,8 +289,8 @@ int main(int argc, char **argv) { const Com bottomLeft = Com(-1.5, -1.5); const Com topRight = Com(1.5, 1.5); - counts = computeCounts(W, H, bottomLeft, topRight); - // Kernel().run_chunked(counts, W, H, bottomLeft, topRight, 42, 1 << 14); + // counts = computeCounts(W, H, bottomLeft, topRight); + Kernel().run_chunked(counts, W, H, bottomLeft, topRight, 42, 1 << 14); // Kernel().run_all(counts, W, H, bottomLeft, topRight, 42); writeCounts(W, H, counts, "out.txt"); |