diff options
Diffstat (limited to 'aberth/kernel.cpp')
-rw-r--r-- | aberth/kernel.cpp | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/aberth/kernel.cpp b/aberth/kernel.cpp new file mode 100644 index 0000000..36c3102 --- /dev/null +++ b/aberth/kernel.cpp @@ -0,0 +1,110 @@ +#include <iostream> +#include "kernel.h" + +extern "C" { +#include "aberth_kernel.h" +} + +using namespace std; + + +void Kernel::check_ret(int ret) { + if (ret != 0) { + char *str = futhark_context_get_error(ctx); + cerr << str << endl; + free(str); + exit(1); + } +}; + +Kernel::Kernel() { + futhark_context_config *config = futhark_context_config_new(); + // futhark_context_config_select_device_interactively(config); + // futhark_context_config_set_debugging(config, 1); + + ctx = futhark_context_new(config); + + futhark_context_config_free(config); + + check_ret(futhark_entry_get_N(ctx, &N)); + // The 31 check is to not exceed 32-bit signed integer bounds + assert(N >= 1 && N < 31); +} + +Kernel::~Kernel() { + futhark_context_free(ctx); +} + +void Kernel::run_job( + vector<int32_t> &dest, + int32_t width, int32_t height, + Com bottomLeft, Com topRight, + int32_t seed, + int32_t start_index, int32_t poly_count) { + + futhark_i32_1d *dest_arr; + + check_ret(futhark_entry_main_job( + ctx, &dest_arr, + start_index, poly_count, + width, height, + bottomLeft.real(), topRight.imag(), + topRight.real(), bottomLeft.imag(), + seed)); + + check_ret(futhark_context_sync(ctx)); + + int64_t shape = futhark_shape_i32_1d(ctx, dest_arr)[0]; + assert(shape == width * height); + + dest.resize(width * height); + check_ret(futhark_values_i32_1d(ctx, dest_arr, dest.data())); + check_ret(futhark_free_i32_1d(ctx, dest_arr)); +} + +void Kernel::run_all( + vector<int32_t> &dest, + int32_t width, int32_t height, + Com bottomLeft, Com topRight, + int32_t seed) { + + run_job(dest, width, height, bottomLeft, topRight, seed, 0, 1 << N); +} + +void Kernel::run_chunked( + vector<int32_t> &dest, + int32_t width, int32_t height, + Com bottomLeft, Com topRight, + int32_t seed, + int32_t chunk_size) { + + dest.clear(); + dest.resize(width * height); + + int32_t start_index = 0; + int32_t total = 1 << N; + + int32_t njobs = (total + chunk_size - 1) / chunk_size; + cerr << "Running " << njobs << " jobs of size " << chunk_size << endl; + cerr << string(njobs, '.') << '\r'; + + while (start_index < total) { + int32_t num_polys = min(chunk_size, total - start_index); + + vector<int32_t> output; + run_job( + output, + width, height, bottomLeft, topRight, seed, + start_index, num_polys); + + for (int i = 0; i < width * height; i++) { + dest[i] += output[i]; + } + + start_index += num_polys; + + cerr << '|'; + } + + cerr << endl; +} |