{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
-- Command-line tool that renders a Mandelbrot-based image with Accelerate and
-- writes the serialised result to a file.
module Main where

import qualified Data.Array.Accelerate as A
-- import qualified Data.Array.Accelerate.Interpreter as Interpreter
import qualified Data.Array.Accelerate.LLVM.Native as CPU
import qualified Data.ByteString.Lazy as BSL
import Data.Function ((&))
import qualified Data.Serialize as Ser
import Data.Word
import qualified Data.Vector.Storable as VS
import Text.Read (readMaybe)
import System.Environment (getArgs)
import System.Exit (exitFailure)

import MandelHSlib


-- | One step of the iteration @z -> z^2 + c@ on complex numbers stored as
-- @(real, imaginary)@ pairs. The first argument is @c@, the second is @z@.
mandelstep :: (A.Ord a, A.Num a, A.Elt a) => A.Exp (a, a) -> A.Exp (a, a) -> A.Exp (a, a)
mandelstep (A.T2 cx cy) (A.T2 x y) =
    A.T2 (x * x - y * y + cx) (2 * x * y + cy)

-- | Number of 'mandelstep' applications performed for the point @c@, starting
-- from @z = c@, before either the counter reaches @maxiter@ or the squared
-- magnitude of @z@ is no longer below 4.
mandeliter :: (A.Ord a, A.Num a, A.Elt a) => A.Exp (a, a) -> Int -> A.Exp Int
mandeliter (A.T2 cx cy) maxiter =
    let A.T3 nres _ _ =
            A.while (\(A.T3 n x y) -> n A.< A.constant maxiter A.&& x * x + y * y A.< 4)
                    (\(A.T3 n x y) ->
                        let A.T2 x' y' = mandelstep (A.T2 cx cy) (A.T2 x y)
                        in A.T3 (n+1) x' y')
                    (A.T3 0 cx cy)
    in nres

-- | Map an array index (row @iy@, column @ix@) of a @w@ x @h@ image to a point
-- of the plane.
--
-- The view is centred on @(cx, cy)@, is @cw@ wide and @cw * h / w@ high.
-- Column 0 maps to the view's minimum x coordinate and column @w - 1@ to its
-- maximum; rows are treated the same way. Because the step size divides by
-- @w - 1@ and @h - 1@, both dimensions need to be at least 2.
pixel :: (A.Num a, Num a, Fractional a, A.Floating a, A.ToFloating Int a, A.Elt a)
      => (Int, Int) -> A.Exp (a, a) -> A.Exp a -> A.Exp A.DIM2 -> A.Exp (a, a)
pixel (w, h) (A.T2 cx cy) cw (A.I2 iy ix) =
    let ch = cw * A.constant (fromIntegral h / fromIntegral w)
        ltx = cx - cw / 2
        lty = cy - ch / 2
    in A.T2 (ltx + cw / fromIntegral (w - 1) * A.toFloating ix)
            (lty + ch / fromIntegral (h - 1) * A.toFloating iy)

-- | Inverse of 'pixel': round a point of the plane to the nearest array index
-- of the same view. Yields 'A.Nothing_' when the rounded index lies outside
-- the @w@ x @h@ image.
invpixel :: (A.Num a, Num a, Fractional a, A.RealFrac a, A.ToFloating Int a, A.Elt a)
         => (Int, Int) -> A.Exp (a, a) -> A.Exp a -> A.Exp (a, a) -> A.Exp (Maybe A.DIM2)
invpixel (w, h) (A.T2 cx cy) cw (A.T2 x y) =
    let ch = cw * A.constant (fromIntegral h / fromIntegral w)
        ltx = cx - cw / 2
        lty = cy - ch / 2
        ix = A.round $ (x - ltx) / cw * fromIntegral (w - 1)
        iy = A.round $ (y - lty) / ch * fromIntegral (h - 1)
    in A.cond (0 A.<= ix A.&& ix A.< A.constant w A.&& 0 A.<= iy A.&& iy A.< A.constant h)
              (A.Just_ (A.I2 iy ix))
              A.Nothing_

-- | Matrix with @h@ rows and @w@ columns holding the 'mandeliter' count of
-- every pixel of the view described by the centre and width scalars.
mandelbrot :: (A.Num a, Num a, A.Ord a, Fractional a, A.Floating a, A.ToFloating Int a, A.Elt a)
           => (Int, Int) -> Int -> A.Acc (A.Scalar (a, a)) -> A.Acc (A.Scalar a) -> A.Acc (A.Matrix Int)
mandelbrot size@(w, h) maxiter (A.the -> cpos) (A.the -> cw) =
    A.generate (A.I2 (A.constant h) (A.constant w)) $ \idx ->
        mandeliter (pixel size cpos cw idx) maxiter

-- | Boolean mask over the 'mandelbrot' matrix: 'True' exactly where the
-- iteration count equals @maxiter@.
mandelSet :: (A.Num a, Num a, A.Ord a, Fractional a, A.Floating a, A.ToFloating Int a, A.Elt a)
          => (Int, Int) -> Int -> A.Acc (A.Scalar (a, a)) -> A.Acc (A.Scalar a) -> A.Acc (A.Matrix Bool)
mandelSet size maxiter cpos cw =
    A.map (A.== A.constant maxiter) (mandelbrot size maxiter cpos cw)

-- | Per-pixel ratio used as the final image value.
--
-- Every pixel's orbit is advanced with 'mandelstep' for @maxiter@ rounds. In
-- each round the new orbit point is mapped back to a pixel with 'invpixel',
-- and the pixel's counter is incremented when that pixel exists and its
-- iteration count equals @maxiter@. The result is that counter divided by the
-- pixel's own 'mandelbrot' iteration count.
--
-- NOTE(review): a pixel whose iteration count is 0 makes the final quotient a
-- division by zero -- confirm that consumers of the result cope with that.
image :: (A.Num a, Num a, A.Ord a, Fractional a, A.RealFrac a, A.Floating a, A.ToFloating Int a, A.Elt a)
      => (Int, Int) -> Int -> A.Acc (A.Scalar (a, a)) -> A.Acc (A.Scalar a) -> A.Acc (A.Matrix a)
image size@(w, h) maxiter acpos@(A.the -> cpos) acw@(A.the -> cw) =
    let mbrot = mandelbrot size maxiter acpos acw
        -- Same mask as 'mandelSet', derived from the already-built matrix.
        mset = A.map (A.== A.constant maxiter) mbrot
        isize = A.I2 (A.constant h) (A.constant w)
        -- The constant c of every pixel; also the starting point of its orbit.
        pixels = A.generate isize (pixel size cpos cw)
        -- Loop state: (round counter, per-pixel (current z, hit counter)).
        A.T2 _ res =
            A.awhile
                (\(A.T2 (A.the -> iter) _) -> A.unit (iter A.< A.constant maxiter))
                (\(A.T2 iter current) ->
                    let current' =
                            A.zipWith
                                (\c (A.T2 z n) ->
                                    let z' = mandelstep c z
                                        n' = invpixel size cpos cw z' & A.match \case
                                                A.Nothing_ -> n
                                                A.Just_ z'pix -> A.cond (mset A.! z'pix) (n+1) n
                                    in A.T2 z' n')
                                pixels
                                current
                    in A.T2 (A.map (+1) iter) current')
                (A.T2 (A.unit 0) (A.map (\z -> A.T2 z (0 :: A.Exp Int)) pixels))
        counts = A.map A.snd res
    in A.zipWith (/) (A.map A.toFloating counts) (A.map A.toFloating mbrot)

-- | Turn a fraction into an RGB triple.
--
-- Negative inputs are clamped to 0, the value is raised to the power 0.3, and
-- each channel is then assembled from smooth ramps over sub-ranges of the
-- result plus a small background term.
colorscheme :: (Ord a, Floating a, RealFrac a) => a -> (Word8, Word8, Word8)
colorscheme fraction
    -- Disabled special case: | fraction == 1.0 = (0, 0, 0)
    | otherwise =
        let x = max 0 fraction ** 0.3
            bg = 0.2 * (1 - curve x (-0.1) 0.2)
        in (tow8 $ bg + curve x 0 0.4 - 0.8 * curve x 0.6 1.0
           ,tow8 $ bg + 0.8 * curve x 0.3 0.7
           ,tow8 $ bg + curve x 0.6 1.0)
  where
    -- Scale by 255, clamp to the range 0..255 and round to a byte.
    tow8 x = round (max 0 (min 255 (x * 255)))
    -- Sine-shaped ramp: 0 at or below @start@, 1 at or above @end@.
    curve x start end
        | x <= start = 0
        | x >= end = 1
        | otherwise = sin (pi/(end-start) * (x - start) - pi/2) / 2 + 0.5

-- | Run 'image' on the native CPU backend for a view centred on @(-0.5, 0)@
-- with width 3.5, and wrap the resulting values in a 'Fractal' together with
-- the image size and iteration limit.
computeFractal :: (Ord a, Floating a, RealFrac a, A.Elt a, A.RealFrac a, A.Floating a, A.ToFloating Int a, VS.Storable a)
               => (Int, Int) -> Int -> Fractal a
computeFractal (w, h) maxiter =
    -- Alternative kept for reference: plain escape-time rendering.
    -- let arr = CPU.runN (\cpos cw ->
    --                 A.map (\n -> A.toFloating n / A.toFloating (A.constant maxiter))
    --                       (mandelbrot (w, h) maxiter cpos cw))
    --             (A.fromList A.Z [(-0.5, 0.0)])
    --             (A.fromList A.Z [3.5])
    let arr = CPU.runN (\cpos cw -> image (w, h) maxiter cpos cw)
                  (A.fromList A.Z [(-0.5, 0.0)])
                  (A.fromList A.Z [3.5])
    in Fractal (w, h) maxiter (VS.fromList (A.toList arr))

-- | Help text printed when the command line cannot be used.
usage :: String
usage =
    "Usage: mandelhs-compute <outfile> <width> <height> <maxiter>\n" ++
    "  width and height must be at least 2; maxiter must be at least 1\n"

-- | Entry point: expects exactly four arguments (output file, width, height,
-- iteration limit), computes the fractal with 'Double' precision and writes
-- its serialised form to the output file. Prints 'usage' and exits with a
-- failure status when the arguments are missing, unparsable or out of range.
main :: IO ()
main = do
    -- let (w, h) = (4 * 2000, 4 * 1500)
    --     maxiter = 1024
    args <- getArgs
    case args of
        [outfile, wids, heis, maxiters] ->
            case (readMaybe wids, readMaybe heis, readMaybe maxiters) of
                (Just w, Just h, Just maxiter)
                    -- 'pixel' and 'invpixel' scale by (w - 1) and (h - 1), so
                    -- smaller images would divide by zero or a negative size.
                    | w >= 2, h >= 2, maxiter >= 1 ->
                        BSL.writeFile outfile $
                            Ser.runPutLazy (Ser.put (computeFractal @Double (w, h) maxiter))
                _ -> putStr usage >> exitFailure
        _ -> putStr usage >> exitFailure