aboutsummaryrefslogtreecommitdiff
path: root/aberth/compute_host.cpp
blob: 0af20df0e6b90d88a1ce2c94745ae741e8219d3b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <cmath>
#include <iostream>
#include <thread>
#include <mutex>
#include <tuple>
#include <vector>
#include <cassert>
#include "compute_host.h"
#include "host_aberth.h"
#include "polygen.h"
#include "util.h"

using namespace std;


// Builds a W x H histogram of polynomial roots over the complex rectangle
// spanned by bottomLeft and topRight.
//
// For every polynomial visited by the PolyGen jobs, aberth() is run and each
// returned root is binned into the nearest pixel. Roots that fall outside the
// rectangle are dropped.
//
// Layout of the result (row-major, counts[W * y + x]):
//   - column 0 corresponds to bottomLeft.real(), column W-1 to topRight.real();
//   - row 0 corresponds to topRight.imag(), row H-1 to bottomLeft.imag()
//     (the imaginary axis is flipped so row 0 is the top of the image).
//
// The work is split into numThreads jobs via PolyGen::genJobs. Presumably the
// jobs together cover all PolyGen::numPolys() polynomials; TODO confirm in
// polygen. Each thread accumulates into a private buffer and merges it into
// the shared result once, under countsMutex.
vector<int> computeHost(int W, int H, Com bottomLeft, Com topRight) {
	constexpr const int numThreads = 4;
	static_assert(ispow2(numThreads));

	cerr << "Finding roots for " << PolyGen::numPolys() << " polynomials" << endl;

	vector<int> counts(W * H);
	mutex countsMutex;

	vector<PolyGen::Job> jobs = PolyGen::genJobs(numThreads);
	assert(jobs.size() == numThreads);

	vector<thread> threads(jobs.size());
	for (size_t t = 0; t < jobs.size(); t++) {
		threads[t] = thread([W, H, &counts, &countsMutex, job = jobs[t], bottomLeft, topRight]() {
			// Maps `value` linearly from [left, right] onto pixel indices
			// [0, steps-1], rounding to the nearest pixel. Returns -1 when the
			// result lies outside the grid.
			//
			// floor() is used instead of relying on (int) truncation, because
			// truncation rounds toward zero: values just left of pixel 0 would
			// otherwise be folded into pixel 0. The range check happens in
			// double space before the cast, because converting a NaN or
			// out-of-range double to int is undefined behaviour. This covers
			// diverged roots and a degenerate rectangle with right == left.
			auto calcIndex = [](double value, double left, double right, int steps) -> int {
				const double pos = std::floor((value - left) / (right - left) * (steps - 1) + 0.5);
				// Negated in-range test so that NaN is rejected as well.
				if (!(pos >= 0.0 && pos < static_cast<double>(steps))) return -1;
				return static_cast<int>(pos);
			};
			// Pixel (x, y) for a root; the imaginary axis runs from top to bottom.
			auto calcPos = [W, H, bottomLeft, topRight, &calcIndex](Com z) -> pair<int, int> {
				return make_pair(
					calcIndex(z.real(), bottomLeft.real(), topRight.real(), W),
					calcIndex(z.imag(), topRight.imag(), bottomLeft.imag(), H)
				);
			};

			// Thread-private histogram so the hot loop never takes the lock.
			vector<int> localCounts(W * H);

			Poly poly = job.init;
			for (int item = 0; item < job.numItems; item++) {
				for (Com z : aberth(poly)) {
					const auto [x, y] = calcPos(z);
					if (0 <= x && x < W && 0 <= y && y < H) {
						localCounts[W * y + x]++;
					}
				}
				PolyGen::next(poly);
			}

			// Merge once per thread.
			lock_guard<mutex> guard(countsMutex);
			for (int px = 0; px < W * H; px++) counts[px] += localCounts[px];
		});
	}

	for (thread &th : threads) th.join();

	return counts;
}