aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--aberth/aberth.cpp124
-rw-r--r--aberth/aberth_kernel.fut4
2 files changed, 98 insertions, 30 deletions
diff --git a/aberth/aberth.cpp b/aberth/aberth.cpp
index 30f40f6..c5b037d 100644
--- a/aberth/aberth.cpp
+++ b/aberth/aberth.cpp
@@ -11,6 +11,7 @@
#include <algorithm>
#include <thread>
#include <mutex>
+#include <type_traits>
#include <cstdint>
#include <cassert>
#include "../lodepng.h"
@@ -346,18 +347,11 @@ static vector<uint8_t> drawImage(int W, int H, const vector<int> &counts, int ma
return image;
}
-static vector<int32_t> invoke_kernel(
- int32_t width, int32_t height,
- Com bottomLeft, Com topRight,
- int32_t seed) {
+class Kernel {
+ futhark_context *ctx;
+ int32_t N;
- futhark_context_config *config = futhark_context_config_new();
- // futhark_context_config_select_device_interactively(config);
- // futhark_context_config_set_debugging(config, 1);
-
- futhark_context *ctx = futhark_context_new(config);
-
- auto check_ret = [ctx](int ret) {
+ void check_ret(int ret) {
if (ret != 0) {
char *str = futhark_context_get_error(ctx);
cerr << str << endl;
@@ -366,29 +360,101 @@ static vector<int32_t> invoke_kernel(
}
};
- futhark_i32_1d *dest_arr;
+public:
+ static_assert(is_same<int32_t, int>::value);
- check_ret(futhark_entry_main_all(
- ctx, &dest_arr,
- width, height,
- bottomLeft.real(), topRight.imag(),
- topRight.real(), bottomLeft.imag(),
- seed));
+ Kernel() {
+ futhark_context_config *config = futhark_context_config_new();
+ // futhark_context_config_select_device_interactively(config);
+ // futhark_context_config_set_debugging(config, 1);
- check_ret(futhark_context_sync(ctx));
+ ctx = futhark_context_new(config);
- // Shouldn't free _this_ pointer, apparently
- int64_t *shape = futhark_shape_i32_1d(ctx, dest_arr);
- assert(shape[0] == width * height);
+ futhark_context_config_free(config);
- vector<int32_t> buffer(width * height);
- check_ret(futhark_values_i32_1d(ctx, dest_arr, buffer.data()));
- check_ret(futhark_free_i32_1d(ctx, dest_arr));
+ check_ret(futhark_entry_get_N(ctx, &N));
+ // The 31 check is to not exceed 32-bit signed integer bounds
+ assert(N >= 1 && N < 31);
+ }
- futhark_context_config_free(config);
+ ~Kernel() {
+ futhark_context_free(ctx);
+ }
- return buffer;
-}
+ void run_job(
+ vector<int32_t> &dest,
+ int32_t width, int32_t height,
+ Com bottomLeft, Com topRight,
+ int32_t seed,
+ int32_t start_index, int32_t poly_count) {
+
+ futhark_i32_1d *dest_arr;
+
+ check_ret(futhark_entry_main_job(
+ ctx, &dest_arr,
+ start_index, poly_count,
+ width, height,
+ bottomLeft.real(), topRight.imag(),
+ topRight.real(), bottomLeft.imag(),
+ seed));
+
+ check_ret(futhark_context_sync(ctx));
+
+ int64_t shape = futhark_shape_i32_1d(ctx, dest_arr)[0];
+ assert(shape == width * height);
+
+ dest.resize(width * height);
+ check_ret(futhark_values_i32_1d(ctx, dest_arr, dest.data()));
+ check_ret(futhark_free_i32_1d(ctx, dest_arr));
+ }
+
+ void run_all(
+ vector<int32_t> &dest,
+ int32_t width, int32_t height,
+ Com bottomLeft, Com topRight,
+ int32_t seed) {
+
+ run_job(dest, width, height, bottomLeft, topRight, seed, 0, 1 << N);
+ }
+
+ void run_chunked(
+ vector<int32_t> &dest,
+ int32_t width, int32_t height,
+ Com bottomLeft, Com topRight,
+ int32_t seed,
+ int32_t chunk_size) {
+
+ dest.clear();
+ dest.resize(width * height);
+
+ int32_t start_index = 0;
+ int32_t total = 1 << N;
+
+ int32_t njobs = (total + chunk_size - 1) / chunk_size;
+ cerr << "Running " << njobs << " jobs of size " << chunk_size << endl;
+ cerr << string(njobs, '.') << '\r';
+
+ while (start_index < total) {
+ int32_t num_polys = min(chunk_size, total - start_index);
+
+ vector<int32_t> output;
+ run_job(
+ output,
+ width, height, bottomLeft, topRight, seed,
+ start_index, num_polys);
+
+ for (int i = 0; i < width * height; i++) {
+ dest[i] += output[i];
+ }
+
+ start_index += num_polys;
+
+ cerr << '|';
+ }
+
+ cerr << endl;
+ }
+};
int main(int argc, char **argv) {
int W, H;
@@ -399,7 +465,7 @@ int main(int argc, char **argv) {
W = H = 2000;
Com bottomLeft = Com(0, -1.6);
Com topRight = Com(1.6, 0);
- counts = invoke_kernel(W, H, bottomLeft, topRight, 42);
+ Kernel().run_chunked(counts, W, H, bottomLeft, topRight, 42, 1 << 14);
writeCounts(W, H, counts, "out.txt");
} else if (argc == 2) {
tie(W, H, counts) = readCounts(argv[1]);
diff --git a/aberth/aberth_kernel.fut b/aberth/aberth_kernel.fut
index be1fed0..e1e7aa3 100644
--- a/aberth/aberth_kernel.fut
+++ b/aberth/aberth_kernel.fut
@@ -7,7 +7,7 @@ module uniform_real = uniform_real_distribution f32 rand_engine
module cplx = mk_complex f32
type complex = cplx.complex
-let N = 18i32
+let N = 22i32
let PolyN = N + 1
type poly = [PolyN]f32
@@ -167,3 +167,5 @@ entry main_all
(seed: i32)
: []i32 =
main_job 0 (1 << N) width height left top right bottom seed
+
+entry get_N: i32 = N