diff options
-rw-r--r-- | aberth/aberth.cpp | 124 | ||||
-rw-r--r-- | aberth/aberth_kernel.fut | 4 |
2 files changed, 98 insertions, 30 deletions
diff --git a/aberth/aberth.cpp b/aberth/aberth.cpp index 30f40f6..c5b037d 100644 --- a/aberth/aberth.cpp +++ b/aberth/aberth.cpp @@ -11,6 +11,7 @@ #include <algorithm> #include <thread> #include <mutex> +#include <type_traits> #include <cstdint> #include <cassert> #include "../lodepng.h" @@ -346,18 +347,11 @@ static vector<uint8_t> drawImage(int W, int H, const vector<int> &counts, int ma return image; } -static vector<int32_t> invoke_kernel( - int32_t width, int32_t height, - Com bottomLeft, Com topRight, - int32_t seed) { +class Kernel { + futhark_context *ctx; + int32_t N; - futhark_context_config *config = futhark_context_config_new(); - // futhark_context_config_select_device_interactively(config); - // futhark_context_config_set_debugging(config, 1); - - futhark_context *ctx = futhark_context_new(config); - - auto check_ret = [ctx](int ret) { + void check_ret(int ret) { if (ret != 0) { char *str = futhark_context_get_error(ctx); cerr << str << endl; @@ -366,29 +360,101 @@ static vector<int32_t> invoke_kernel( } }; - futhark_i32_1d *dest_arr; +public: + static_assert(is_same<int32_t, int>::value); - check_ret(futhark_entry_main_all( - ctx, &dest_arr, - width, height, - bottomLeft.real(), topRight.imag(), - topRight.real(), bottomLeft.imag(), - seed)); + Kernel() { + futhark_context_config *config = futhark_context_config_new(); + // futhark_context_config_select_device_interactively(config); + // futhark_context_config_set_debugging(config, 1); - check_ret(futhark_context_sync(ctx)); + ctx = futhark_context_new(config); - // Shouldn't free _this_ pointer, apparently - int64_t *shape = futhark_shape_i32_1d(ctx, dest_arr); - assert(shape[0] == width * height); + futhark_context_config_free(config); - vector<int32_t> buffer(width * height); - check_ret(futhark_values_i32_1d(ctx, dest_arr, buffer.data())); - check_ret(futhark_free_i32_1d(ctx, dest_arr)); + check_ret(futhark_entry_get_N(ctx, &N)); + // The 31 check is to not exceed 32-bit signed integer bounds + assert(N >= 1 && N < 31); + } - futhark_context_config_free(config); + ~Kernel() { + futhark_context_free(ctx); + } - return buffer; -} + void run_job( + vector<int32_t> &dest, + int32_t width, int32_t height, + Com bottomLeft, Com topRight, + int32_t seed, + int32_t start_index, int32_t poly_count) { + + futhark_i32_1d *dest_arr; + + check_ret(futhark_entry_main_job( + ctx, &dest_arr, + start_index, poly_count, + width, height, + bottomLeft.real(), topRight.imag(), + topRight.real(), bottomLeft.imag(), + seed)); + + check_ret(futhark_context_sync(ctx)); + + int64_t shape = futhark_shape_i32_1d(ctx, dest_arr)[0]; + assert(shape == width * height); + + dest.resize(width * height); + check_ret(futhark_values_i32_1d(ctx, dest_arr, dest.data())); + check_ret(futhark_free_i32_1d(ctx, dest_arr)); + } + + void run_all( + vector<int32_t> &dest, + int32_t width, int32_t height, + Com bottomLeft, Com topRight, + int32_t seed) { + + run_job(dest, width, height, bottomLeft, topRight, seed, 0, 1 << N); + } + + void run_chunked( + vector<int32_t> &dest, + int32_t width, int32_t height, + Com bottomLeft, Com topRight, + int32_t seed, + int32_t chunk_size) { + + dest.clear(); + dest.resize(width * height); + + int32_t start_index = 0; + int32_t total = 1 << N; + + int32_t njobs = (total + chunk_size - 1) / chunk_size; + cerr << "Running " << njobs << " jobs of size " << chunk_size << endl; + cerr << string(njobs, '.') << '\r'; + + while (start_index < total) { + int32_t num_polys = min(chunk_size, total - start_index); + + vector<int32_t> output; + run_job( + output, + width, height, bottomLeft, topRight, seed, + start_index, num_polys); + + for (int i = 0; i < width * height; i++) { + dest[i] += output[i]; + } + + start_index += num_polys; + + cerr << '|'; + } + + cerr << endl; + } +}; int main(int argc, char **argv) { int W, H; @@ -399,7 +465,7 @@ int main(int argc, char **argv) { W = H = 2000; Com bottomLeft = Com(0, -1.6); Com topRight = Com(1.6, 0); - counts = invoke_kernel(W, H, bottomLeft, topRight, 42); + Kernel().run_chunked(counts, W, H, bottomLeft, topRight, 42, 1 << 14); writeCounts(W, H, counts, "out.txt"); } else if (argc == 2) { tie(W, H, counts) = readCounts(argv[1]); diff --git a/aberth/aberth_kernel.fut b/aberth/aberth_kernel.fut index be1fed0..e1e7aa3 100644 --- a/aberth/aberth_kernel.fut +++ b/aberth/aberth_kernel.fut @@ -7,7 +7,7 @@ module uniform_real = uniform_real_distribution f32 rand_engine module cplx = mk_complex f32 type complex = cplx.complex -let N = 18i32 +let N = 22i32 let PolyN = N + 1 type poly = [PolyN]f32 @@ -167,3 +167,5 @@ entry main_all (seed: i32) : []i32 = main_job 0 (1 << N) width height left top right bottom seed + +entry get_N: i32 = N |