aboutsummaryrefslogtreecommitdiff
path: root/aberth/aberth_kernel.fut
blob: f868ee521ef88825ea15748e123371fa66959e20 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import "lib/github.com/diku-dk/complex/complex"
import "lib/github.com/diku-dk/cpprandom/random"

-- Random number engine and the f32 uniform distribution drawn from it.
module rand_engine = minstd_rand
module uniform_real = uniform_real_distribution f32 rand_engine

-- Single-precision complex arithmetic.
module cplx = mk_complex f32
type complex = cplx.complex

-- Degree of the polynomials; a degree-N polynomial has N + 1 coefficients.
let N: i64 = 22
let PolyN: i64 = N + 1

-- Coefficient array, lowest-order (constant) coefficient first.
type poly = [PolyN]f32

-- Every pair (x, y) with 0 <= x < n and 0 <= y < m, ordered so that the
-- first element of the pair steps fastest.
let iota2 (n: i64) (m: i64): [](i64, i64) =
    tabulate (m * n) (\k -> (k % n, k / n))

-- Evaluate the polynomial formed by the first 'nterms' coefficients of 'p'
-- at the complex point 'pt' using Horner's scheme, starting from the
-- highest included coefficient and working down to the constant term.
let evaln_c (p: poly) (nterms: i64) (pt: complex): complex =
    loop acc = cplx.mk_re p[nterms-1] for k < nterms - 1 do
        cplx.mk_re p[nterms - 2 - k] cplx.+ pt cplx.* acc

-- Evaluate all PolyN coefficients of 'p' at the complex point 'pt'.
let eval_c (p: poly) (pt: complex): complex = evaln_c p PolyN pt

-- Real-valued counterpart of evaln_c: Horner evaluation of the first
-- 'nterms' coefficients of 'p' at the real point 'pt'.
let evaln_d (p: poly) (nterms: i64) (pt: f32): f32 =
    loop acc = p[nterms-1] for k < nterms - 1 do
        p[nterms - 2 - k] + pt * acc

-- Evaluate all PolyN coefficients of 'p' at the real point 'pt'.
let eval_d (p: poly) (pt: f32): f32 = evaln_d p PolyN pt

-- Coefficients of p' in the same PolyN-sized layout as 'p': entry i is
-- (i + 1) * p[i + 1]. The top entry has no source coefficient and is
-- produced by scaling the wrapped-around p[0] by zero.
let derivative (p: poly): *poly =
    tabulate PolyN (\i ->
        let factor = if i == PolyN - 1 then 0 else i + 1
        in f32.i64 factor * p[(i + 1) % PolyN])

-- Cauchy's bound: https://en.wikipedia.org/wiki/Geometrical_properties_of_polynomial_roots#Lagrange's_and_Cauchy's_bounds
-- Every root z satisfies |z| <= 1 + max over the non-leading coefficients
-- a_i of |a_i / a_lead|, where a_lead = p[PolyN-1].
let max_root_norm (p: poly): f32 =
    let lead = p[PolyN-1]
    let ratios = map (\coef -> f32.abs (coef / lead)) (init p)
    in 1 + f32.maximum ratios

module aberth = {
    -- One approximation per root of the degree-N polynomial.
    type approx = [N]complex
    -- 'bound' is 's' in the stop condition formulated at p.189-190 of
    -- https://link.springer.com/article/10.1007%2FBF02207694.
    -- 'radius' bounds the magnitude of every root (see max_root_norm) and
    -- scales the random initial approximations.
    type context = {p: poly, deriv: poly, bound: poly, radius: f32}

    -- Steps granted to the first attempt; attempt k is granted k times as
    -- many steps before the approximations are regenerated.
    let steps_per_try: i32 = 100

    -- Draw one coordinate between -r and r from uniform_real.
    let gen_coord (r: f32) (rng: *rand_engine.rng): *(rand_engine.rng, f32) =
        uniform_real.rand (-r, r) rng

    -- Draw a complex number whose real and imaginary parts each come from
    -- gen_coord, threading the generator state through both draws.
    let gen_coord_c (r: f32) (rng: *rand_engine.rng): (rand_engine.rng, complex) =
        let (rng, x) = gen_coord r rng
        let (rng, y) = gen_coord r rng
        in (rng, cplx.mk x y)

    -- Generate N random initial approximations with real and imaginary parts
    -- between -ctx.radius and ctx.radius. One split generator is used per
    -- approximation, and the re-joined generator is returned with them.
    let generate (ctx: context) (rng: *rand_engine.rng): *(rand_engine.rng, approx) =
        let rngs = rand_engine.split_rng N rng
        let (rngs, approx) = unzip (map (\rng -> gen_coord_c ctx.radius (copy rng)) rngs)
        let rng = rand_engine.join_rng rngs
        in (rng, approx)

    -- Coefficients |a_i| * (4i + 1) of the bound polynomial 's'.
    let compute_bound_poly (p: poly): *poly =
        map2 (\coef i -> f32.abs coef * f32.i64 (4 * i + 1)) p (0..<PolyN)

    -- Precompute everything 'step' needs that depends only on 'p'.
    let initialise (p: *poly): *context =
        let deriv = derivative p
        let bound = compute_bound_poly p
        let radius = max_root_norm p
        in {p, deriv, bound, radius}

    -- Jacobi-style step: the new elements are computed in parallel from the
    -- previous values. Returns (converged, updated approximations), where
    -- 'converged' is true when every approximation this step started from
    -- satisfied |p(z)| <= 1e-5 * s(|z|).
    let step (ctx: context) (approx: *approx): *(bool, approx) =
        let pvals = map (eval_c ctx.p) approx
        -- Evaluate the bound at the same points as 'pvals' so both sides of
        -- the stop condition refer to the same approximation. Previously
        -- this used the updated approximations, mixing two different points.
        let svals = map (eval_d ctx.bound <-< cplx.mag) approx
        let conditions = map2 (\p s -> cplx.mag p <= 1e-5 * s) pvals svals
        let all_converged = all id conditions
        let derivvals = map (evaln_c ctx.deriv (PolyN - 1)) approx
        let quos = map2 (cplx./) pvals derivvals
        -- sums[i] = sum over j /= i of 1 / (z_i - z_j)
        let sums = map (\i ->
                            reduce_comm (cplx.+) (cplx.mk_re 0.0)
                                (map (\j ->
                                        if i == j then cplx.mk_re 0.0
                                        else cplx.mk_re 1.0 cplx./
                                                (approx[i] cplx.- approx[j]))
                                     (0..<N)))
                       (0..<N)
        let offsets = map2 (\quo sum -> quo cplx./ (cplx.mk_re 1.0 cplx.- quo cplx.* sum))
                           quos sums
        let new_approx = map2 (cplx.-) approx offsets
        in (all_converged, new_approx)

    -- Run Aberth steps from random starting points until the stop condition
    -- holds. When the current attempt has used tries * steps_per_try steps,
    -- fresh approximations are generated and the budget grows.
    -- NOTE(review): there is no cap on the number of attempts; if the stop
    -- condition is never met this loop does not terminate.
    let iterate (ctx: context) (rng: *rand_engine.rng): (rand_engine.rng, *approx) =
        let (rng, approx) = generate ctx rng
        let (init_conv, approx) = step ctx approx
        let (rng, _, _, _, approx) =
            -- step_idx counts the steps taken in the current attempt.
            loop (rng, conv, tries, step_idx, approx) =
                    (rng, init_conv, 1, 1: i32, approx)
            while !conv
            do if step_idx >= tries * steps_per_try
               then let (rng, approx) = generate ctx rng
                    let (conv, approx) = step ctx approx
                    -- The fresh attempt has already taken one step.
                    in (rng, conv, tries + 1, 1, approx)
               else let (conv, approx) = step ctx approx
                    in (rng, conv, tries, step_idx + 1, approx)
        in (rng, approx)

    -- Approximate all N roots of 'p', returning the advanced generator too.
    let aberth (p: *poly) (rng: *rand_engine.rng): *(rand_engine.rng, approx) =
        iterate (initialise p) rng
}

-- First polynomial of the enumeration: constant coefficient fixed to 1 and
-- every other coefficient -1. This is the same polynomial that
-- derbyshire_at_index returns for index 0.
let init_derbyshire: poly =
    tabulate PolyN (\i -> if i == 0 then 1 else -1)

-- Advance 'p' to the next Derbyshire polynomial. The non-constant
-- coefficients act as a binary counter (-1 = digit 0, 1 = digit 1) with
-- the least significant digit at index 1. The constant coefficient
-- (index 0) is never touched, so the enumeration matches
-- derbyshire_at_index, which always fixes that coefficient to +1.
-- Returns (wrapped, p'), where 'wrapped' is true when every non-constant
-- coefficient was 1, i.e. the counter overflowed and all of them were
-- reset to -1.
let next_derbyshire (p: *poly): *(bool, poly) =
    let (_, p, looped) =
        -- Start at index 1: index 0 is the constant coefficient.
        loop (i, p, cont) = (1, p, true)
        while cont && i < length p
        do if p[i] == -1
           then (i,     p with [i] =  1, false)
           else (i + 1, p with [i] = -1, true)
    in (looped, p)

-- The Derbyshire polynomial with the given index. Coefficient k is +1 when
-- bit k of ((index << 1) | 1) is set and -1 otherwise. The constant
-- coefficient is therefore always +1, and bit k-1 of 'index' selects
-- coefficient k.
let derbyshire_at_index (index: i32): *poly =
    let bits = (index << 1) | 1
    let coeff (k: i64): f32 = if i32.get_bit (i32.i64 k) bits == 1 then 1 else -1
    in tabulate PolyN coeff


-- Map 'value' to the nearest of 'steps' evenly spaced sample indices
-- 0 .. steps-1, where 'left' maps to 0 and 'right' maps to steps-1
-- ('left' may be greater than 'right'). Values more than half a sample
-- spacing outside [left, right] map to indices outside 0 .. steps-1.
let calc_index (value: f32) (left: f32) (right: f32) (steps: i32): i32 =
    -- Round half up with an explicit floor. i32.f32 alone truncates toward
    -- zero, which folded values up to 1.5 samples before 'left' onto 0.
    i32.f32 (f32.floor ((value - left) / (right - left) * (f32.i32 steps - 1) + 0.5))

-- Flat row-major pixel index of 'pt' in a width x height image spanning the
-- rectangle from 'bottom_left' to 'top_right', or -1 when 'pt' lands
-- outside the image.
let point_index
        (width: i32) (height: i32)
        (bottom_left: complex) (top_right: complex)
        (pt: complex)
        : i32 =
    let col = calc_index (cplx.re pt) (cplx.re bottom_left) (cplx.re top_right) width
    -- Image rows grow downwards while the imaginary axis grows upwards, so
    -- the vertical range runs from the top edge to the bottom edge.
    let row = calc_index (cplx.im pt) (cplx.im top_right) (cplx.im bottom_left) height
    let inside = col >= 0 && col < width && row >= 0 && row < height
    in if !inside then -1 else row * width + col

-- Histogram of root positions for the Derbyshire polynomials with indices
-- start_index .. start_index + num_polys - 1, drawn on a width x height
-- grid covering [left, right] x [bottom, top] of the complex plane.
-- Returns one count per pixel in row-major order.
entry main_job
        (start_index: i32) (num_polys: i32)
        (width: i32) (height: i32)
        (left: f32) (top: f32) (right: f32) (bottom: f32)
        (seed: i32)
        : []i32 =
    -- Unnecessary to give each polynomial a different seed
    let rng = rand_engine.rng_from_seed [seed]
    let bottom_left = cplx.mk left bottom
    let top_right = cplx.mk right top
    let indices = flatten
            (map (\idx ->
                    let p = derbyshire_at_index idx
                    let (_, pts) = aberth.aberth p (copy rng)
                    in map (point_index width height bottom_left top_right) pts)
                 (start_index ..< start_index + num_polys))
    -- Roots outside the image have index -1. reduce_by_index ignores
    -- out-of-bounds indices, so no separate filtering pass is needed.
    in reduce_by_index (replicate (i64.i32 width * i64.i32 height) 0)
                       (+) 0
                       (map i64.i32 indices)
                       (map (const 1) indices)

-- Histogram over every one of the 2^N Derbyshire polynomials; see main_job
-- for the meaning of the remaining parameters.
entry main_all
        (width: i32) (height: i32)
        (left: f32) (top: f32) (right: f32) (bottom: f32)
        (seed: i32)
        : []i32 =
    let num_polys = i32.i64 (1 << N)
    in main_job 0 num_polys width height left top right bottom seed

-- Expose the polynomial degree N (the number of roots per polynomial).
entry get_N: i32 = N |> i32.i64