aboutsummaryrefslogtreecommitdiff
path: root/aberth/kernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'aberth/kernel.cpp')
-rw-r--r--aberth/kernel.cpp110
1 files changed, 110 insertions, 0 deletions
diff --git a/aberth/kernel.cpp b/aberth/kernel.cpp
new file mode 100644
index 0000000..36c3102
--- /dev/null
+++ b/aberth/kernel.cpp
@@ -0,0 +1,110 @@
+#include <iostream>
+#include "kernel.h"
+
+extern "C" {
+#include "aberth_kernel.h"
+}
+
+using namespace std;
+
+
+// Terminate the program if a Futhark API call reported failure.
+//
+// ret: return code of a Futhark call; zero is treated as success and the
+//      function returns without doing anything. For any other value the
+//      context's pending error message is written to stderr and the
+//      process exits with status 1.
+void Kernel::check_ret(int ret) {
+    if (ret == 0) {
+        return;
+    }
+
+    // The returned string is owned by the caller and must be free()d.
+    // The Futhark API returns NULL when no message is pending; streaming a
+    // null char* is undefined behaviour, so print a generic message instead.
+    char *str = futhark_context_get_error(ctx);
+    if (str != nullptr) {
+        cerr << str << endl;
+        free(str);
+    } else {
+        cerr << "Futhark call failed with code " << ret
+             << " (no error message available)" << endl;
+    }
+    exit(1);
+}
+
+// Create the Futhark context and read the kernel's compile-time parameter N.
+//
+// Exits with status 1 if the context cannot be initialised, if querying N
+// fails, or if N is outside [1, 30].
+Kernel::Kernel() {
+    futhark_context_config *config = futhark_context_config_new();
+    // futhark_context_config_select_device_interactively(config);
+    // futhark_context_config_set_debugging(config, 1);
+
+    ctx = futhark_context_new(config);
+
+    // NOTE(review): some Futhark releases keep a pointer to the config inside
+    // the context, in which case the config must outlive the context. Confirm
+    // against the Futhark version in use; deferring this free to the
+    // destructor would need a new member declared in kernel.h.
+    futhark_context_config_free(config);
+
+    if (ctx == nullptr) {
+        cerr << "Failed to allocate Futhark context" << endl;
+        exit(1);
+    }
+
+    // Context creation reports failure through the error string rather than
+    // a return code, so it has to be queried explicitly before first use.
+    char *init_error = futhark_context_get_error(ctx);
+    if (init_error != nullptr) {
+        cerr << init_error << endl;
+        free(init_error);
+        exit(1);
+    }
+
+    check_ret(futhark_entry_get_N(ctx, &N));
+
+    // The upper bound keeps `1 << N` within 32-bit signed integer range.
+    // This is a runtime check (not an assert) so it still guards the shifts
+    // in run_all/run_chunked when the file is built with NDEBUG.
+    if (N < 1 || N >= 31) {
+        cerr << "Unsupported N = " << N << " (expected 1 <= N < 31)" << endl;
+        exit(1);
+    }
+}
+
+// Release the Futhark context.
+Kernel::~Kernel() {
+    // The Futhark C API asks for a sync before the context is freed so that
+    // no asynchronous operation is still outstanding. A failure here is only
+    // reported, not fatal: terminating from a destructor would skip the free.
+    if (futhark_context_sync(ctx) != 0) {
+        char *str = futhark_context_get_error(ctx);
+        if (str != nullptr) {
+            cerr << str << endl;
+            free(str);
+        }
+    }
+    futhark_context_free(ctx);
+}
+
+// Run the `main_job` entry point for one range of polynomials and copy the
+// resulting array into `dest`.
+//
+// dest:        receives the kernel output; resized to width * height.
+// width,
+// height:      output dimensions forwarded to the kernel.
+// bottomLeft,
+// topRight:    corners of the viewed region. They are forwarded as
+//              (bottomLeft.real(), topRight.imag()) followed by
+//              (topRight.real(), bottomLeft.imag()).
+// seed:        forwarded unchanged to the kernel.
+// start_index,
+// poly_count:  first polynomial index and number of polynomials to process.
+//
+// Any Futhark failure, or an output whose length is not width * height,
+// terminates the process with status 1.
+void Kernel::run_job(
+        vector<int32_t> &dest,
+        int32_t width, int32_t height,
+        Com bottomLeft, Com topRight,
+        int32_t seed,
+        int32_t start_index, int32_t poly_count) {
+
+    futhark_i32_1d *dest_arr;
+
+    check_ret(futhark_entry_main_job(
+        ctx, &dest_arr,
+        start_index, poly_count,
+        width, height,
+        bottomLeft.real(), topRight.imag(),
+        topRight.real(), bottomLeft.imag(),
+        seed));
+
+    // Entry points may run asynchronously; sync so that errors surface here.
+    check_ret(futhark_context_sync(ctx));
+
+    // Compute the expected size in 64 bits so large images cannot overflow.
+    const int64_t expected = static_cast<int64_t>(width) * height;
+    const int64_t shape = futhark_shape_i32_1d(ctx, dest_arr)[0];
+
+    // Checked at runtime rather than with assert: if the sizes disagreed in
+    // an NDEBUG build, the copy below would write past the end of `dest`.
+    if (shape != expected) {
+        cerr << "Kernel returned " << shape << " values, expected "
+             << expected << endl;
+        exit(1);
+    }
+
+    dest.resize(static_cast<vector<int32_t>::size_type>(shape));
+    check_ret(futhark_values_i32_1d(ctx, dest_arr, dest.data()));
+
+    // The copy out of the Futhark array is asynchronous as well; wait for it
+    // to finish before the array is freed and before the caller reads `dest`.
+    check_ret(futhark_context_sync(ctx));
+
+    check_ret(futhark_free_i32_1d(ctx, dest_arr));
+}
+
+// Process every polynomial (indices 0 .. 2^N - 1) in a single job and store
+// the result in `dest`. All other arguments are forwarded to run_job as-is.
+void Kernel::run_all(
+        vector<int32_t> &dest,
+        int32_t width, int32_t height,
+        Com bottomLeft, Com topRight,
+        int32_t seed) {
+
+    const int32_t first_poly = 0;
+    const int32_t all_polys = 1 << N;
+
+    run_job(
+        dest,
+        width, height,
+        bottomLeft, topRight,
+        seed,
+        first_poly, all_polys);
+}
+
+// Process all 2^N polynomials in jobs of at most `chunk_size` polynomials
+// each, and accumulate the element-wise sum of the job outputs in `dest`.
+//
+// dest:        overwritten with the summed output (width * height elements).
+// width,
+// height,
+// bottomLeft,
+// topRight,
+// seed:        forwarded unchanged to every run_job call.
+// chunk_size:  maximum number of polynomials per job; must be positive.
+//
+// Progress is written to stderr: a row of dots (one per job) that is
+// overwritten by one '|' per finished job. A non-positive chunk_size
+// terminates the process with status 1.
+void Kernel::run_chunked(
+        vector<int32_t> &dest,
+        int32_t width, int32_t height,
+        Com bottomLeft, Com topRight,
+        int32_t seed,
+        int32_t chunk_size) {
+
+    // Zero would divide by zero below and a negative size would never
+    // advance start_index, so reject both up front.
+    if (chunk_size <= 0) {
+        cerr << "Invalid chunk size " << chunk_size
+             << " (must be positive)" << endl;
+        exit(1);
+    }
+
+    // Start from an all-zero accumulator of the full image size; the size is
+    // computed in 64 bits so large images cannot overflow the product.
+    const int64_t pixels = static_cast<int64_t>(width) * height;
+    dest.assign(static_cast<vector<int32_t>::size_type>(pixels), 0);
+
+    int32_t start_index = 0;
+    const int32_t total = 1 << N;
+
+    // Ceiling division, done in 64 bits: total + chunk_size - 1 can exceed
+    // the int32_t range when chunk_size is large.
+    const int64_t njobs =
+        (static_cast<int64_t>(total) + chunk_size - 1) / chunk_size;
+    cerr << "Running " << njobs << " jobs of size " << chunk_size << endl;
+    cerr << string(static_cast<string::size_type>(njobs), '.') << '\r';
+
+    // Reused across jobs so the buffer is allocated only once.
+    vector<int32_t> output;
+
+    while (start_index < total) {
+        // The last job may be shorter than chunk_size.
+        const int32_t num_polys = min(chunk_size, total - start_index);
+
+        run_job(
+            output,
+            width, height, bottomLeft, topRight, seed,
+            start_index, num_polys);
+
+        for (vector<int32_t>::size_type i = 0; i < dest.size(); i++) {
+            dest[i] += output[i];
+        }
+
+        start_index += num_polys;
+
+        cerr << '|';
+    }
+
+    cerr << endl;
+}