aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortomsmeding <tom.smeding@gmail.com>2019-04-20 19:41:16 +0200
committertomsmeding <tom.smeding@gmail.com>2019-04-20 19:41:16 +0200
commitccade8e6ed96fb48b329d22aaa4c8d0826b3c8d1 (patch)
treeb9389a557869a38d0cdb9de02247abf1446e5db3
parentdc6f869c48c267e2091d686a66209c4f741d62e4 (diff)
Move some code out of main.cpp
-rw-r--r--aberth/defs.h13
-rw-r--r--aberth/kernel.cpp110
-rw-r--r--aberth/kernel.h44
-rw-r--r--aberth/main.cpp174
-rw-r--r--aberth/util.cpp27
-rw-r--r--aberth/util.h31
6 files changed, 230 insertions, 169 deletions
diff --git a/aberth/defs.h b/aberth/defs.h
new file mode 100644
index 0000000..40eacf0
--- /dev/null
+++ b/aberth/defs.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <complex>
+#include <array>
+
+using namespace std;
+
+
+constexpr const int N = 18;
+
+using Com = complex<double>;
+using Poly = array<int, N + 1>;
+using AApprox = array<Com, N>;
diff --git a/aberth/kernel.cpp b/aberth/kernel.cpp
new file mode 100644
index 0000000..36c3102
--- /dev/null
+++ b/aberth/kernel.cpp
@@ -0,0 +1,110 @@
+#include <iostream>
+#include "kernel.h"
+
+extern "C" {
+#include "aberth_kernel.h"
+}
+
+using namespace std;
+
+
+void Kernel::check_ret(int ret) {
+ if (ret != 0) {
+ char *str = futhark_context_get_error(ctx);
+ cerr << str << endl;
+ free(str);
+ exit(1);
+ }
+};
+
+Kernel::Kernel() {
+ futhark_context_config *config = futhark_context_config_new();
+ // futhark_context_config_select_device_interactively(config);
+ // futhark_context_config_set_debugging(config, 1);
+
+ ctx = futhark_context_new(config);
+
+ futhark_context_config_free(config);
+
+ check_ret(futhark_entry_get_N(ctx, &N));
+ // The 31 check is to not exceed 32-bit signed integer bounds
+ assert(N >= 1 && N < 31);
+}
+
+Kernel::~Kernel() {
+ futhark_context_free(ctx);
+}
+
+void Kernel::run_job(
+ vector<int32_t> &dest,
+ int32_t width, int32_t height,
+ Com bottomLeft, Com topRight,
+ int32_t seed,
+ int32_t start_index, int32_t poly_count) {
+
+ futhark_i32_1d *dest_arr;
+
+ check_ret(futhark_entry_main_job(
+ ctx, &dest_arr,
+ start_index, poly_count,
+ width, height,
+ bottomLeft.real(), topRight.imag(),
+ topRight.real(), bottomLeft.imag(),
+ seed));
+
+ check_ret(futhark_context_sync(ctx));
+
+ int64_t shape = futhark_shape_i32_1d(ctx, dest_arr)[0];
+ assert(shape == width * height);
+
+ dest.resize(width * height);
+ check_ret(futhark_values_i32_1d(ctx, dest_arr, dest.data()));
+ check_ret(futhark_free_i32_1d(ctx, dest_arr));
+}
+
+void Kernel::run_all(
+ vector<int32_t> &dest,
+ int32_t width, int32_t height,
+ Com bottomLeft, Com topRight,
+ int32_t seed) {
+
+ run_job(dest, width, height, bottomLeft, topRight, seed, 0, 1 << N);
+}
+
+void Kernel::run_chunked(
+ vector<int32_t> &dest,
+ int32_t width, int32_t height,
+ Com bottomLeft, Com topRight,
+ int32_t seed,
+ int32_t chunk_size) {
+
+ dest.clear();
+ dest.resize(width * height);
+
+ int32_t start_index = 0;
+ int32_t total = 1 << N;
+
+ int32_t njobs = (total + chunk_size - 1) / chunk_size;
+ cerr << "Running " << njobs << " jobs of size " << chunk_size << endl;
+ cerr << string(njobs, '.') << '\r';
+
+ while (start_index < total) {
+ int32_t num_polys = min(chunk_size, total - start_index);
+
+ vector<int32_t> output;
+ run_job(
+ output,
+ width, height, bottomLeft, topRight, seed,
+ start_index, num_polys);
+
+ for (int i = 0; i < width * height; i++) {
+ dest[i] += output[i];
+ }
+
+ start_index += num_polys;
+
+ cerr << '|';
+ }
+
+ cerr << endl;
+}
diff --git a/aberth/kernel.h b/aberth/kernel.h
new file mode 100644
index 0000000..bd8ea32
--- /dev/null
+++ b/aberth/kernel.h
@@ -0,0 +1,44 @@
+#pragma once
+
+#include <vector>
+#include "defs.h"
+
+extern "C" {
+#include "aberth_kernel.h"
+}
+
+using namespace std;
+
+
+class Kernel {
+ futhark_context *ctx;
+ int32_t N;
+
+ void check_ret(int ret);
+
+public:
+ static_assert(is_same<int32_t, int>::value);
+
+ Kernel();
+ ~Kernel();
+
+ void run_job(
+ vector<int32_t> &dest,
+ int32_t width, int32_t height,
+ Com bottomLeft, Com topRight,
+ int32_t seed,
+ int32_t start_index, int32_t poly_count);
+
+ void run_all(
+ vector<int32_t> &dest,
+ int32_t width, int32_t height,
+ Com bottomLeft, Com topRight,
+ int32_t seed);
+
+ void run_chunked(
+ vector<int32_t> &dest,
+ int32_t width, int32_t height,
+ Com bottomLeft, Com topRight,
+ int32_t seed,
+ int32_t chunk_size);
+};
diff --git a/aberth/main.cpp b/aberth/main.cpp
index eb3cef7..1cdea3a 100644
--- a/aberth/main.cpp
+++ b/aberth/main.cpp
@@ -15,68 +15,13 @@
#include <cstdint>
#include <cassert>
#include "../lodepng.h"
-
-extern "C" {
-#include "aberth_kernel.h"
-}
+#include "defs.h"
+#include "kernel.h"
+#include "util.h"
using namespace std;
-constexpr const int N = 18;
-
-using Com = complex<double>;
-
-using Poly = array<int, N + 1>;
-
-using AApprox = array<Com, N>;
-
-
-template <typename T>
-constexpr static T clearLowestBit(T value) {
- return value & (value - 1);
-}
-
-template <typename T>
-constexpr static bool ispow2(T value) {
- return clearLowestBit(value) == 0;
-}
-
-template <typename T>
-constexpr static T ceil2(T value) {
- T value2 = clearLowestBit(value);
- if (value2 == 0) return value;
-
- while (true) {
- value = value2;
- value2 = clearLowestBit(value);
- if (value2 == 0) return value << 1;
- }
-}
-
-__attribute__((unused))
-static ostream& operator<<(ostream &os, const Poly &p) {
- static const char *supers[10] = {
- "⁰", "¹", "²", "³", "⁴", "⁵", "⁶", "⁷", "⁸", "⁹"
- };
- os << p[0];
- for (int i = 1; i < (int)p.size(); i++) {
- if (p[i] < 0) os << " - " << -p[i];
- else if (p[i] > 0) os << " + " << p[i];
- else continue;
-
- os << "x";
-
- if (i == 1) continue;
-
- ostringstream ss;
- ss << i;
- string s = ss.str();
- for (char c : s) os << supers[c - '0'];
- }
- return os;
-}
-
template <typename T>
static T eval(const Poly &p, int nterms, T pt) {
T value = p[nterms - 1];
@@ -335,115 +280,6 @@ static vector<uint8_t> drawImage(int W, int H, const vector<int> &counts, int ma
return image;
}
-class Kernel {
- futhark_context *ctx;
- int32_t N;
-
- void check_ret(int ret) {
- if (ret != 0) {
- char *str = futhark_context_get_error(ctx);
- cerr << str << endl;
- free(str);
- exit(1);
- }
- };
-
-public:
- static_assert(is_same<int32_t, int>::value);
-
- Kernel() {
- futhark_context_config *config = futhark_context_config_new();
- // futhark_context_config_select_device_interactively(config);
- // futhark_context_config_set_debugging(config, 1);
-
- ctx = futhark_context_new(config);
-
- futhark_context_config_free(config);
-
- check_ret(futhark_entry_get_N(ctx, &N));
- // The 31 check is to not exceed 32-bit signed integer bounds
- assert(N >= 1 && N < 31);
- }
-
- ~Kernel() {
- futhark_context_free(ctx);
- }
-
- void run_job(
- vector<int32_t> &dest,
- int32_t width, int32_t height,
- Com bottomLeft, Com topRight,
- int32_t seed,
- int32_t start_index, int32_t poly_count) {
-
- futhark_i32_1d *dest_arr;
-
- check_ret(futhark_entry_main_job(
- ctx, &dest_arr,
- start_index, poly_count,
- width, height,
- bottomLeft.real(), topRight.imag(),
- topRight.real(), bottomLeft.imag(),
- seed));
-
- check_ret(futhark_context_sync(ctx));
-
- int64_t shape = futhark_shape_i32_1d(ctx, dest_arr)[0];
- assert(shape == width * height);
-
- dest.resize(width * height);
- check_ret(futhark_values_i32_1d(ctx, dest_arr, dest.data()));
- check_ret(futhark_free_i32_1d(ctx, dest_arr));
- }
-
- void run_all(
- vector<int32_t> &dest,
- int32_t width, int32_t height,
- Com bottomLeft, Com topRight,
- int32_t seed) {
-
- run_job(dest, width, height, bottomLeft, topRight, seed, 0, 1 << N);
- }
-
- void run_chunked(
- vector<int32_t> &dest,
- int32_t width, int32_t height,
- Com bottomLeft, Com topRight,
- int32_t seed,
- int32_t chunk_size) {
-
- dest.clear();
- dest.resize(width * height);
-
- int32_t start_index = 0;
- int32_t total = 1 << N;
-
- int32_t njobs = (total + chunk_size - 1) / chunk_size;
- cerr << "Running " << njobs << " jobs of size " << chunk_size << endl;
- cerr << string(njobs, '.') << '\r';
-
- while (start_index < total) {
- int32_t num_polys = min(chunk_size, total - start_index);
-
- vector<int32_t> output;
- run_job(
- output,
- width, height, bottomLeft, topRight, seed,
- start_index, num_polys);
-
- for (int i = 0; i < width * height; i++) {
- dest[i] += output[i];
- }
-
- start_index += num_polys;
-
- cerr << '|';
- }
-
- cerr << endl;
- }
-};
-
int main(int argc, char **argv) {
int W, H;
vector<int> counts;
@@ -453,8 +289,8 @@ int main(int argc, char **argv) {
const Com bottomLeft = Com(-1.5, -1.5);
const Com topRight = Com(1.5, 1.5);
- counts = computeCounts(W, H, bottomLeft, topRight);
- // Kernel().run_chunked(counts, W, H, bottomLeft, topRight, 42, 1 << 14);
+ // counts = computeCounts(W, H, bottomLeft, topRight);
+ Kernel().run_chunked(counts, W, H, bottomLeft, topRight, 42, 1 << 14);
// Kernel().run_all(counts, W, H, bottomLeft, topRight, 42);
writeCounts(W, H, counts, "out.txt");
diff --git a/aberth/util.cpp b/aberth/util.cpp
new file mode 100644
index 0000000..b8f6edc
--- /dev/null
+++ b/aberth/util.cpp
@@ -0,0 +1,27 @@
+#include "util.h"
+
+
+ostream& operator<<(ostream &os, const Poly &p) {
+ static const char *supers[10] = {
+ "⁰", "¹", "²", "³", "⁴", "⁵", "⁶", "⁷", "⁸", "⁹"
+ };
+
+ os << p[0];
+
+ for (int i = 1; i < (int)p.size(); i++) {
+ if (p[i] < 0) os << " - " << -p[i];
+ else if (p[i] > 0) os << " + " << p[i];
+ else continue;
+
+ os << "x";
+
+ if (i == 1) continue;
+
+ ostringstream ss;
+ ss << i;
+ string s = ss.str();
+ for (char c : s) os << supers[c - '0'];
+ }
+
+ return os;
+}
diff --git a/aberth/util.h b/aberth/util.h
new file mode 100644
index 0000000..c21a7bb
--- /dev/null
+++ b/aberth/util.h
@@ -0,0 +1,31 @@
+#pragma once
+
+#include <iostream>
+#include "defs.h"
+
+using namespace std;
+
+
+template <typename T>
+constexpr static T clearLowestBit(T value) {
+ return value & (value - 1);
+}
+
+template <typename T>
+constexpr static bool ispow2(T value) {
+ return clearLowestBit(value) == 0;
+}
+
+template <typename T>
+constexpr static T ceil2(T value) {
+ T value2 = clearLowestBit(value);
+ if (value2 == 0) return value;
+
+ while (true) {
+ value = value2;
+ value2 = clearLowestBit(value);
+ if (value2 == 0) return value << 1;
+ }
+}
+
+ostream& operator<<(ostream &os, const Poly &p);