aboutsummaryrefslogtreecommitdiff
path: root/aberth/kernel.cpp
blob: 92d2d2040f725ebe5e2ab483b0455cbdf0501b4e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>
#include "kernel.h"

extern "C" {
#include "aberth_kernel.h"
}

using namespace std;


void Kernel::check_ret(int ret) {
    if (ret != 0) {
        char *str = futhark_context_get_error(ctx);
        cerr << str << endl;
        free(str);
        exit(1);
    }
};

// Create the Futhark context and read the compiled-in parameter N.
// Exits the process if the context cannot be created or N is out of range.
Kernel::Kernel() {
    futhark_context_config *config = futhark_context_config_new();
    // futhark_context_config_select_device_interactively(config);
    // futhark_context_config_set_debugging(config, 1);

    ctx = futhark_context_new(config);

    // NOTE(review): the Futhark C API examples free the context before its
    // config; freeing the config while ctx is still live may be unsafe on
    // some Futhark versions. Fixing that needs the config stored as a member
    // (kernel.h) -- confirm against the Futhark version in use.
    futhark_context_config_free(config);

    if (ctx == nullptr) {
        cerr << "Failed to create Futhark context" << endl;
        exit(1);
    }
    // Initialisation failures may be recorded in the context's error state
    // rather than signalled by a NULL return; get_error yields NULL if clean.
    if (char *err = futhark_context_get_error(ctx)) {
        cerr << err << endl;
        free(err);
        exit(1);
    }

    check_ret(futhark_entry_get_N(ctx, &N));
    // N must lie in [1, 30] so that 1 << N fits in a signed 32-bit integer.
    // This is a runtime check, not an assert, because the shifts in run_all /
    // run_chunked are undefined behaviour otherwise and asserts vanish under
    // NDEBUG.
    if (!(N >= 1 && N < 31)) {
        cerr << "Kernel N = " << N << " is out of range [1, 30]" << endl;
        exit(1);
    }
}

// Release the Futhark context that the constructor created.
Kernel::~Kernel() {
    futhark_context_free(this->ctx);
}

// Run the Futhark `main_job` entry point for polynomials
// [start_index, start_index + poly_count) and copy its width*height int32
// result into `dest` (which is resized to fit).
// Exits the process on any Futhark error or if the result has the wrong size.
void Kernel::run_job(
        vector<int32_t> &dest,
        int32_t width, int32_t height,
        Com bottomLeft, Com topRight,
        int32_t seed,
        int32_t start_index, int32_t poly_count) {

    futhark_i32_1d *dest_arr = nullptr;

    // Corners are passed as (left, top, right, bottom); presumably this
    // matches main_job's parameter order -- confirm against the .fut source.
    check_ret(futhark_entry_main_job(
            ctx, &dest_arr,
            start_index, poly_count,
            width, height,
            bottomLeft.real(), topRight.imag(),
            topRight.real(), bottomLeft.imag(),
            seed));

    // Entry points may execute asynchronously; syncing surfaces any kernel
    // error before the result is inspected.
    check_ret(futhark_context_sync(ctx));

    // 64-bit product so large images cannot overflow int32.
    const int64_t expected = static_cast<int64_t>(width) * height;
    const int64_t shape = futhark_shape_i32_1d(ctx, dest_arr)[0];
    if (shape != expected) {
        cerr << "main_job returned " << shape << " values, expected "
             << expected << endl;
        exit(1);
    }

    // Size the buffer from the shape Futhark reported so the copy below can
    // never write past the end of dest (the old assert vanished with NDEBUG).
    dest.resize(static_cast<size_t>(shape));
    check_ret(futhark_values_i32_1d(ctx, dest_arr, dest.data()));
    // futhark_values_* may be asynchronous: wait for the copy to complete
    // before handing dest back to the caller.
    check_ret(futhark_context_sync(ctx));
    check_ret(futhark_free_i32_1d(ctx, dest_arr));
}

// Process the whole index range [0, 2^N) in a single run_job call,
// leaving the result in `dest`.
void Kernel::run_all(
        vector<int32_t> &dest,
        int32_t width, int32_t height,
        Com bottomLeft, Com topRight,
        int32_t seed) {
    const int32_t first = 0;
    const int32_t count = 1 << N;
    run_job(dest, width, height, bottomLeft, topRight, seed, first, count);
}

// Process the index range [0, 2^N) in jobs of at most `chunk_size`
// polynomials, summing each job's width*height result element-wise into
// `dest`. Progress is drawn on stderr: one '.' per job, overwritten by '|'
// as each job finishes.
// Exits the process if chunk_size is not positive.
void Kernel::run_chunked(
        vector<int32_t> &dest,
        int32_t width, int32_t height,
        Com bottomLeft, Com topRight,
        int32_t seed,
        int32_t chunk_size) {

    // chunk_size == 0 would divide by zero below; chunk_size < 0 would make
    // start_index walk backwards and the loop never terminate.
    if (chunk_size <= 0) {
        cerr << "run_chunked: chunk_size must be positive, got "
             << chunk_size << endl;
        exit(1);
    }

    // Zero-filled accumulator; 64-bit product avoids int32 overflow.
    const size_t pixel_count =
        static_cast<size_t>(static_cast<int64_t>(width) * height);
    dest.assign(pixel_count, 0);

    const int32_t total = 1 << N;

    // Ceiling division written so that total + chunk_size cannot overflow.
    const int32_t njobs = total / chunk_size + (total % chunk_size != 0 ? 1 : 0);
    cerr << "Running " << njobs << " jobs of size " << chunk_size << endl;
    cerr << string(njobs, '.') << '\r';

    // Reused across jobs so each iteration does not reallocate its buffer.
    vector<int32_t> output;

    int32_t start_index = 0;
    while (start_index < total) {
        const int32_t num_polys = min(chunk_size, total - start_index);

        run_job(
            output,
            width, height, bottomLeft, topRight, seed,
            start_index, num_polys);

        // run_job sizes output to width * height, the same as dest.
        assert(output.size() >= dest.size());
        transform(dest.begin(), dest.end(), output.begin(), dest.begin(),
                  [](int32_t acc, int32_t v) { return acc + v; });

        start_index += num_polys;

        cerr << '|';
    }

    cerr << endl;
}