aboutsummaryrefslogtreecommitdiff
path: root/aberth
diff options
context:
space:
mode:
authortomsmeding <tom.smeding@gmail.com>2019-04-19 23:40:28 +0200
committertomsmeding <tom.smeding@gmail.com>2019-04-19 23:40:28 +0200
commit76e91597879672744ac07515f173efb5b3dcfc08 (patch)
tree68694b73340741e7d247bd99200bcce4502a6f26 /aberth
parentbe2bc936957d4fcbc2001ca48909bb2ce8d3f5c7 (diff)
Correct root bounds checking in futhark
Diffstat (limited to 'aberth')
-rw-r--r--aberth/aberth_kernel.fut17
1 files changed, 10 insertions, 7 deletions
diff --git a/aberth/aberth_kernel.fut b/aberth/aberth_kernel.fut
index 815c254..73a82ee 100644
--- a/aberth/aberth_kernel.fut
+++ b/aberth/aberth_kernel.fut
@@ -123,16 +123,19 @@ let derbyshire_at_index (index: i32): *poly =
in tabulate PolyN (\i -> f64.i32 (i32.get_bit i bitfield * 2 - 1))
+let calc_index (value: f64) (left: f64) (right: f64) (steps: i32): i32 =
+ t64 ((value - left) / (right - left) * (r64 steps - 1) + 0.5)
+
let point_index
(width: i32) (height: i32)
(top_left: complex) (bottom_right: complex)
(pt: complex)
: i32 =
- let x = (c64.re pt - c64.re top_left) / (c64.re bottom_right - c64.re top_left)
- let y = (c64.im pt - c64.im top_left) / (c64.im bottom_right - c64.im top_left)
- let xi = t64 (x * r64 (width - 1))
- let yi = t64 (y * r64 (height - 1))
- in width * yi + xi
+ let xi = calc_index (c64.re pt) (c64.re top_left) (c64.re bottom_right) width
+ let yi = calc_index (c64.im pt) (c64.im bottom_right) (c64.im top_left) height
+ in if 0 <= xi && xi < width && 0 <= yi && yi < height
+ then width * yi + xi
+ else -1
entry main_job
(start_index: i32) (num_polys: i32)
@@ -144,13 +147,13 @@ entry main_job
let rng = rand_engine.rng_from_seed [seed]
let top_left = c64.mk left top
let bottom_right = c64.mk right bottom
- let indices = flatten
+ let indices = filter (\i -> i != -1) (flatten
(map (\idx ->
let p = derbyshire_at_index idx
let (_, pts) = aberth.aberth p rng
let indices = map (point_index width height top_left bottom_right) pts
in indices)
- (start_index ..< start_index + num_polys))
+ (start_index ..< start_index + num_polys)))
in reduce_by_index (replicate (width * height) 0) (+) 0 indices (replicate (length indices) 1)
entry main_all