#include #include #include "kernel.h" extern "C" { #include "aberth_kernel.h" } using namespace std; void Kernel::check_ret(int ret) { if (ret != 0) { char *str = futhark_context_get_error(ctx); cerr << str << endl; free(str); exit(1); } }; Kernel::Kernel() { futhark_context_config *config = futhark_context_config_new(); // futhark_context_config_select_device_interactively(config); // futhark_context_config_set_debugging(config, 1); ctx = futhark_context_new(config); futhark_context_config_free(config); check_ret(futhark_entry_get_N(ctx, &N)); // The 31 check is to not exceed 32-bit signed integer bounds assert(N >= 1 && N < 31); } Kernel::~Kernel() { futhark_context_free(ctx); } void Kernel::run_job( vector &dest, int32_t width, int32_t height, Com bottomLeft, Com topRight, int32_t seed, int32_t start_index, int32_t poly_count) { futhark_i32_1d *dest_arr; check_ret(futhark_entry_main_job( ctx, &dest_arr, start_index, poly_count, width, height, bottomLeft.real(), topRight.imag(), topRight.real(), bottomLeft.imag(), seed)); check_ret(futhark_context_sync(ctx)); int64_t shape = futhark_shape_i32_1d(ctx, dest_arr)[0]; assert(shape == width * height); dest.resize(width * height); check_ret(futhark_values_i32_1d(ctx, dest_arr, dest.data())); check_ret(futhark_free_i32_1d(ctx, dest_arr)); } void Kernel::run_all( vector &dest, int32_t width, int32_t height, Com bottomLeft, Com topRight, int32_t seed) { run_job(dest, width, height, bottomLeft, topRight, seed, 0, 1 << N); } void Kernel::run_chunked( vector &dest, int32_t width, int32_t height, Com bottomLeft, Com topRight, int32_t seed, int32_t chunk_size) { dest.clear(); dest.resize(width * height); int32_t start_index = 0; int32_t total = 1 << N; int32_t njobs = (total + chunk_size - 1) / chunk_size; cerr << "Running " << njobs << " jobs of size " << chunk_size << endl; cerr << string(njobs, '.') << '\r'; while (start_index < total) { int32_t num_polys = min(chunk_size, total - start_index); vector output; run_job( output, width, height, bottomLeft, topRight, seed, start_index, num_polys); for (int i = 0; i < width * height; i++) { dest[i] += output[i]; } start_index += num_polys; cerr << '|'; } cerr << endl; }