summaryrefslogtreecommitdiff
path: root/compute/Compute.hs
blob: ee515c70c46d2ac520f709d0df1d76f8cd2db88b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
module Main (main) where

import Data.Function ((&))
import System.Environment (getArgs)
import System.Exit (exitFailure)
import System.IO (hPutStr, stderr)
import Text.Read (readMaybe)

import qualified Data.Array.Accelerate as A
-- import qualified Data.Array.Accelerate.Interpreter as Interpreter
import qualified Data.Array.Accelerate.LLVM.Native as CPU
import qualified Data.ByteString.Lazy as BSL
import qualified Data.Serialize as Ser
import qualified Data.Vector.Storable as VS

import MandelHSlib


-- | One step of the quadratic map @z -> z^2 + c@.
--
-- Complex numbers are encoded as @(real, imaginary)@ pairs. The first
-- argument is the constant @c@, the second the current point @z@.
mandelstep :: (A.Ord a, A.Num a, A.Elt a) => A.Exp (a, a) -> A.Exp (a, a) -> A.Exp (a, a)
mandelstep c z =
  let A.T2 cre cim = c
      A.T2 zre zim = z
      -- Re(z^2 + c) and Im(z^2 + c), same operation order as z*z expanded.
      re' = zre * zre - zim * zim + cre
      im' = 2 * zre * zim + cim
  in A.T2 re' im'

-- | Escape-time count for the point @c@.
--
-- Starting from @z = c@, 'mandelstep' is applied while fewer than @maxiter@
-- steps have been taken and @|z|^2 < 4@. The number of steps actually taken
-- is returned. It is 0 when @c@ itself already has @|c|^2 >= 4@, and
-- @maxiter@ when the orbit never left the radius-2 disc.
mandeliter :: (A.Ord a, A.Num a, A.Elt a) => A.Exp (a, a) -> Int -> A.Exp Int
mandeliter c@(A.T2 cx cy) maxiter = steps
  where
    A.T3 steps _ _ = A.while stillBounded advance (A.T3 0 cx cy)

    -- Keep iterating while under the cap and still inside |z|^2 < 4.
    stillBounded = \(A.T3 n x y) ->
      n A.< A.constant maxiter A.&& x * x + y * y A.< 4

    -- Bump the counter and move z one step along its orbit.
    advance = \(A.T3 n x y) ->
      let A.T2 x' y' = mandelstep c (A.T2 x y)
      in A.T3 (n + 1) x' y'

-- | Complex coordinate of the pixel at index @(row, col)@.
--
-- The view is centred at @c@ and is @cw@ wide. Its height is
-- @cw * h / w@, so pixels are square. Column 0 maps to the left edge
-- (@Re c - cw/2@) and column @w - 1@ to the right edge. Rows run from the
-- smallest imaginary part (row 0) to the largest (row @h - 1@).
--
-- Because it divides by @w - 1@ and @h - 1@, both dimensions must be at
-- least 2 for the coordinates to be finite.
pixel :: (A.Num a, Num a, Fractional a, A.Floating a, A.ToFloating Int a, A.Elt a)
      => (Int, Int) -> A.Exp (a, a) -> A.Exp a -> A.Exp A.DIM2 -> A.Exp (a, a)
pixel (w, h) (A.T2 cx cy) cw (A.I2 row col) =
  A.T2 (left + stepX * A.toFloating col)
       (bottom + stepY * A.toFloating row)
  where
    -- Height of the view, derived from the aspect ratio of the image.
    ch = cw * A.constant (fromIntegral h / fromIntegral w)
    left = cx - cw / 2
    bottom = cy - ch / 2
    -- Distance between neighbouring pixel centres along each axis.
    stepX = cw / fromIntegral (w - 1)
    stepY = ch / fromIntegral (h - 1)

-- | Inverse of 'pixel'.
--
-- Maps a complex point back to the index @(row, col)@ of the nearest pixel
-- centre in the @w x h@ view centred at @c@ with width @cw@. Returns
-- 'A.Nothing_' when the point falls outside the image.
--
-- The fractional pixel coordinates are range-checked while still floating
-- point, /before/ 'A.round' converts them to 'Int'. 'image' passes in orbit
-- points that keep growing after escaping and can overflow to ±Infinity
-- and then NaN. Float-to-Int conversion of such values is not reliably
-- defined, so they are rejected first. For finite values the answer is the
-- same as rounding first and bounds-checking the integers: anything the
-- loose pre-check @(-1, n)@ rejects would round outside @[0, n)@ anyway.
invpixel :: (A.Num a, Num a, Fractional a, A.RealFrac a, A.ToFloating Int a, A.Elt a)
         => (Int, Int) -> A.Exp (a, a) -> A.Exp a -> A.Exp (a, a) -> A.Exp (Maybe A.DIM2)
invpixel (w, h) (A.T2 cx cy) cw (A.T2 x y) =
  let ch = cw * A.constant (fromIntegral h / fromIntegral w)
      ltx = cx - cw / 2
      lty = cy - ch / 2
      -- Fractional pixel coordinates; pixel centres sit at whole numbers.
      fx = (x - ltx) / cw * fromIntegral (w - 1)
      fy = (y - lty) / ch * fromIntegral (h - 1)
      -- Loose pre-check; every comparison is False for NaN, and ±Infinity
      -- or huge finite values fall outside the interval.
      fxOk = (-1) A.< fx A.&& fx A.< fromIntegral w
      fyOk = (-1) A.< fy A.&& fy A.< fromIntegral h
      ix = A.round fx
      iy = A.round fy
  in A.cond (fxOk A.&& fyOk)
       (A.cond (0 A.<= ix A.&& ix A.< A.constant w A.&&
                0 A.<= iy A.&& iy A.< A.constant h)
          (A.Just_ (A.I2 iy ix))
          A.Nothing_)
       A.Nothing_

-- | Escape-time counts ('mandeliter') for every pixel of a @w x h@ image.
--
-- The view is described by the scalar centre @cpos@ and width @cw@ (see
-- 'pixel'). The result has @h@ rows and @w@ columns.
mandelbrot :: (A.Num a, Num a, A.Ord a, Fractional a, A.Floating a, A.ToFloating Int a, A.Elt a)
           => (Int, Int) -> Int -> A.Acc (A.Scalar (a, a)) -> A.Acc (A.Scalar a)
           -> A.Acc (A.Matrix Int)
mandelbrot size@(w, h) maxiter acpos acw = A.generate extent escapeCount
  where
    extent = A.I2 (A.constant h) (A.constant w)
    cpos = A.the acpos
    cw = A.the acw
    -- Count for a single pixel index.
    escapeCount idx = mandeliter (pixel size cpos cw idx) maxiter

-- | Per-pixel ratio image built on top of 'mandelbrot'.
--
-- For every pixel with coordinate @c@, the orbit @z -> z^2 + c@ (starting at
-- @z = c@) is followed for exactly @maxiter@ steps. A step counts as a hit
-- when its new point lands (via 'invpixel') on a pixel whose escape count
-- reached @maxiter@. Each hit increments a per-pixel counter. The result is
-- that hit count divided by the pixel's own escape count from 'mandelbrot'.
--
-- 'mandeliter' returns 0 for pixels whose @c@ already fails @|c|^2 < 4@.
-- Those pixels yield 0 instead of the NaN (0/0) or Infinity (n/0) that a
-- plain division would produce.
image :: (A.Num a, Num a, A.Ord a, Fractional a, A.RealFrac a, A.Floating a, A.ToFloating Int a, A.Elt a)
      => (Int, Int) -> Int -> A.Acc (A.Scalar (a, a)) -> A.Acc (A.Scalar a)
      -> A.Acc (A.Matrix a)
image size@(w, h) maxiter acpos@(A.the -> cpos) acw@(A.the -> cw) =
  let -- Escape count of every pixel, and which pixels reached 'maxiter'.
      mbrot = mandelbrot (w, h) maxiter acpos acw
      mset = A.map (A.== A.constant maxiter) mbrot
      isize = A.I2 (A.constant h) (A.constant w)
      -- Complex coordinate c of every pixel.
      pixels = A.generate isize (pixel size cpos cw)
      -- Loop state: (steps taken so far, per-pixel (current z, hit count)).
      A.T2 _ res =
        A.awhile (\(A.T2 (A.the -> iter) _) -> A.unit (iter A.< A.constant maxiter))
                 (\(A.T2 iter current) ->
                   let current' =
                         A.zipWith (\c (A.T2 z n) ->
                                      let z' = mandelstep c z
                                          n' = invpixel size cpos cw z' & A.match \case
                                                 A.Nothing_ -> n
                                                 A.Just_ z'pix ->
                                                   A.cond (mset A.! z'pix) (n+1) n
                                      in A.T2 z' n')
                                   pixels current
                   in A.T2 (A.map (+1) iter) current')
                 (A.T2 (A.unit 0)
                       (A.map (\z -> A.T2 z (0 :: A.Exp Int)) pixels))
      counts = A.map A.snd res
  in A.zipWith (\hits escapes ->
                  -- Guard against escape count 0 (see the header comment).
                  A.cond (escapes A.== 0)
                         0
                         (A.toFloating hits / A.toFloating escapes))
               counts mbrot


-- | Render 'image' for a @w x h@ grid on the LLVM native (CPU) backend.
--
-- The view is centred at @(-0.5, 0)@ and is 3.5 wide. The pixels are
-- returned in 'A.toList' order, packed into a 'Fractal' together with the
-- size and @maxiter@. NOTE(review): presumably Fractal's fields are
-- (size, maxiter, pixel data) as constructed here; confirm in MandelHSlib.
--
-- An earlier variant rendered the normalised escape count
-- (@mandelbrot / maxiter@) instead of 'image'.
computeFractal :: (Int, Int) -> Int -> Fractal
computeFractal size maxiter = Fractal size maxiter (VS.fromList (A.toList arr))
  where
    centre = A.fromList A.Z [(-0.5, 0.0)]
    width = A.fromList A.Z [3.5]
    arr = CPU.runN (image size maxiter) centre width

-- | Command-line synopsis, printed when the arguments are rejected.
usage :: String
usage = unlines
  [ "Usage: mandelhs-compute <outfile.data> <width> <height> <maxiter>" ]

-- | Entry point: @mandelhs-compute <outfile.data> <width> <height> <maxiter>@.
--
-- Computes the fractal with 'computeFractal' and writes the bytes produced
-- by 'Ser.put' to @outfile.data@. A malformed or out-of-range argument list
-- prints 'usage' to stderr and exits with failure.
--
-- Width and height must be at least 2 and @maxiter@ at least 1:
--
-- * 'pixel' divides by @w - 1@ and @h - 1@, so a size of 1 gives NaN
--   coordinates, and non-positive sizes give empty or invalid array shapes.
-- * With @maxiter = 0@, every pixel counts as "in the set" and every
--   escape count is 0.
--
-- Example: @mandelhs-compute out.data 8000 6000 1024@.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [outfile, wids, heis, maxiters]
      | Just w <- readMaybe wids
      , Just h <- readMaybe heis
      , Just maxiter <- readMaybe maxiters
      , w >= 2
      , h >= 2
      , maxiter >= 1 ->
          BSL.writeFile outfile $
            Ser.runPutLazy (Ser.put (computeFractal (w, h) maxiter))
    _ -> failWithUsage
  where
    -- Diagnostics go to stderr so they never mix with regular output.
    failWithUsage = hPutStr stderr usage >> exitFailure