diff options
author | Tom Smeding <tom@tomsmeding.com> | 2021-10-10 19:55:59 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2021-10-10 19:55:59 +0200 |
commit | 1640830bf5dc0630481e698512064215eb3e8249 (patch) | |
tree | 229b5666508e1152b5fff77733e48539591af0ab | |
parent | ff220bfb4c4c67f666a4701f2514d8de432f1e9a (diff) |
-rw-r--r-- | .gitignore | 3 | ||||
-rw-r--r-- | Language/C.hs | 4 | ||||
-rw-r--r-- | Language/C/Print.hs | 23 | ||||
-rw-r--r-- | SC/Acc.hs | 2 | ||||
-rw-r--r-- | SC/Exp.hs | 25 | ||||
-rw-r--r-- | accelerate-sc.cabal | 15 | ||||
-rw-r--r-- | output.c (renamed from test.c) | 0 | ||||
-rw-r--r-- | test/Examples/Mandel.hs | 60 | ||||
-rw-r--r-- | test/Examples/Mandel/Main.hs | 14 | ||||
-rw-r--r-- | test/Examples/Test.hs | 10 | ||||
-rw-r--r-- | test/Examples/Utils/PPM.hs | 18 | ||||
-rw-r--r-- | test/Main.hs | 20 |
12 files changed, 187 insertions, 7 deletions
@@ -1,2 +1,3 @@ dist-newstyle/ -test +output +mandel.ppm diff --git a/Language/C.hs b/Language/C.hs index 9a65115..8b2a9d5 100644 --- a/Language/C.hs +++ b/Language/C.hs @@ -44,6 +44,9 @@ data Stmt | SCall Name [Expr] -- | @SFor ty i lo hi body@: @for (ty i = lo; i < hi; i++) body@ | SFor Type Name Expr Expr [Stmt] + -- | @SWhile pre cond post@: @while (true) { pre; if (!cond) break; post; }@ + -- Special case if @pre == []@: @while (cond) { post; } + | SWhile [Stmt] Expr [Stmt] | SIf Expr [Stmt] [Stmt] deriving (Show, Eq) @@ -56,6 +59,7 @@ data Expr | EIndex Name Expr | EPtrTo Expr | ESizeOf Type + | ECast Type Expr deriving (Show, Eq) diff --git a/Language/C/Print.hs b/Language/C/Print.hs index 5852601..3d7688c 100644 --- a/Language/C/Print.hs +++ b/Language/C/Print.hs @@ -50,10 +50,14 @@ printType = printString . showType printBits B64 = "64" printBlock :: [Stmt] -> PrintS -printBlock ss = - printString "{\n " - % addIndent 2 (intercalates "\n" (map printStmt ss)) - % printString "\n}" +printBlock ss = printBlock' (intercalates "\n" (map printStmt ss)) + +printBlock' :: PrintS -> PrintS +printBlock' body = + getIndent $ \d -> + printString ("{\n" ++ replicate (d + 2) ' ') + % addIndent 2 body + % printString "\n}" printStmt :: Stmt -> PrintS printStmt (SDecl ty name rhs) = @@ -75,6 +79,13 @@ printStmt (SFor ty name lo hi body) = % printString "; " % printName name % printString "++) " % printBlock body +printStmt (SWhile [] cond body) = + printString "while (" % printExpr cond % printString ") " % printBlock body +printStmt (SWhile pre cond post) = + printString "while (true) " + % printBlock' (intercalates "\n" (map printStmt pre) + % printString "\nif (" % printExpr cond % printString ") break;\n" + % intercalates "\n" (map printStmt post)) printStmt (SIf e b1 b2) = printString "if (" % printExpr e % printString ") " % printBlock b1 % printString " else " % printBlock b2 @@ -93,6 +104,7 @@ printExpr (ECall name args) = printExpr (EIndex name e) = printName name % printString "[" % printExpr e % printString "]" printExpr (EPtrTo e) = printString "&(" % printExpr e % printString ")" printExpr (ESizeOf t) = printString "(sizeof (" % printType t % printString "))" +printExpr (ECast t e) = printString "(" % printType t % printString ")(" % printExpr e % printString ")" addIndent :: Int -> PrintS -> PrintS @@ -101,6 +113,9 @@ addIndent plusd f d = f (d + plusd) withIndent :: Int -> PrintS -> PrintS withIndent d f _ = f d +getIndent :: (Int -> PrintS) -> PrintS +getIndent f d = f d d + intercalates :: String -> [PrintS] -> PrintS intercalates sep l = foldr (%) (\_ -> id) $ intersperse (printString sep) l @@ -88,6 +88,8 @@ compilePAcc' aenv destnames = \case usedA = map (\(TypedAName _ n) -> n) (itupList arrnames) return [CChunk [] sts usedA] + A.Anil -> return [] + A.Apair a b | ANPair destnames1 destnames2 <- destnames -> do res1 <- compileAcc' aenv destnames1 a @@ -11,7 +11,7 @@ import Data.Array.Accelerate.Representation.Shape import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type --- import Debug.Trace +import Debug.Trace import Debug import qualified Language.C as C @@ -46,7 +46,7 @@ compileFun aenv (A.Lam lhs (A.Body body)) = do outnames <- itupmap (\(TypedName t n) -> TypedName (C.TPtr t) n) <$> genVars (A.expType body) ((_tree, usedA), res) <- compileExp' aenv env body - -- traceM ("Compiled expression:\n" ++ prettyTree " " " " tree) + traceM ("Compiled expression:\n" ++ prettyTree " " " " _tree) (sts1, retexprs) <- toStExprs (A.expType body) res let sts2 = genoutstores outnames retexprs arrayarguments = @@ -116,6 +116,15 @@ compileExp' aenv env = \case ITupIgnore -> [] ITupSingle _ -> error "wat")) + A.While (A.Lam condlhs (A.Body condexp)) (A.Lam bodylhs (A.Body bodyexp)) initexp -> do + names <- genVars (lhsToTupR condlhs) + let condenv = pushVarsLHS condlhs names env + bodyenv = pushVarsLHS condlhs names env + ((tree1, usedA1), res1) <- compileExp' aenv env condexp + ((tree2, usedA2), res2) <- compileExp' aenv env bodyexp + ((tree3, usedA3), res3) <- compileExp' aenv env initexp + undefined + A.Const ty x | Just str <- showExpConst ty x -> return ((Leaf ("Const (" ++ str ++ ")"), []), Right ([], ITupSingle (C.ELit str))) @@ -125,6 +134,11 @@ compileExp' aenv env = \case A.PrimApp (A.PrimMul _) e -> binary aenv env "*" e A.PrimApp (A.PrimQuot _) e -> binary aenv env "/" e A.PrimApp (A.PrimRem _) e -> binary aenv env "%" e + A.PrimApp (A.PrimFDiv _) e -> binary aenv env "/" e + A.PrimApp (A.PrimLog TypeFloat) e -> unary aenv env "log" (C.ECall (C.Name "logf") . pure) e + A.PrimApp (A.PrimLog TypeDouble) e -> unary aenv env "log" (C.ECall (C.Name "log") . pure) e + A.PrimApp (A.PrimToFloating _ TypeFloat) e -> unary aenv env "cast float" (C.ECast C.TFloat) e + A.PrimApp (A.PrimToFloating _ TypeDouble) e -> unary aenv env "cast double" (C.ECast C.TDouble) e A.PrimApp op _ -> throw $ "Unsupported Exp primitive operator: " ++ showPrimFun op A.Shape (Var _ idx) -> @@ -174,6 +188,13 @@ compileExp' aenv env = \case toStExprs (A.expType e') res return ((Node ("binary " ++ show op) [tree], usedA), Right (sts, ITupSingle (C.EOp e1 op e2))) + unary :: AVarEnv aenv -> VarEnv env -> String -> (C.Expr -> C.Expr) -> A.OpenExp env aenv a + -> SC ((Tree, [SomeArray]), Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t)) + unary aenv' env' name op e' = do + ((tree, usedA), res) <- compileExp' aenv' env' e' + (sts, ITupSingle e1) <- toStExprs (A.expType e') res + return ((Node ("unary " ++ name) [tree], usedA), Right (sts, ITupSingle (op e1))) + toStExprs :: TypeR t -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t) toStExprs ty (Left fun) = do names <- genVars ty diff --git a/accelerate-sc.cabal b/accelerate-sc.cabal index 378bb2a..5a1923f 100644 --- a/accelerate-sc.cabal +++ b/accelerate-sc.cabal @@ -30,3 +30,18 @@ library hs-source-dirs: . default-language: Haskell2010 ghc-options: -Wall -O2 + +test-suite test + type: exitcode-stdio-1.0 + main-is: Main.hs + other-modules: + Examples.Mandel + Examples.Test + Examples.Utils.PPM + hs-source-dirs: test + ghc-options: -Wall -O2 -threaded + build-depends: + accelerate-sc, + base >= 4.13 && < 4.15, + accelerate ^>= 1.3.0.0 + default-language: Haskell2010 diff --git a/test/Examples/Mandel.hs b/test/Examples/Mandel.hs new file mode 100644 index 0000000..adb116b --- /dev/null +++ b/test/Examples/Mandel.hs @@ -0,0 +1,60 @@ +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} +module Examples.Mandel (afun) where + +import Prelude () +import qualified Prelude as P +import Data.Array.Accelerate + + +type Dims = (Int, Int) +type Pos = (Double, Double) +type Viewport = (Dims -- image size + ,Pos -- midpoint + ,Double) -- complex width of the viewport + +type RGB = (Word8, Word8, Word8) + +-- Arguments: viewport and maxiter +-- Result: image in row-major order +afun :: Acc (Scalar (Viewport, Int)) -> Acc (Matrix RGB) +afun (the -> T2 viewport maxiter) = mandel viewport maxiter (clrBasic maxiter) + +clrBasic :: Exp Int -> Exp Int -> Exp RGB +clrBasic maxiterI nI = + let maxiter = log (toFloating maxiterI) + n = log (toFloating nI) + in cond (n == maxiter) + (T3 0 0 0) + (let r = slope (Just (maxiter / 4)) Nothing n maxiter + g = slope Nothing Nothing n maxiter + b = slope Nothing (Just (maxiter * 3 / 4)) n maxiter + in T3 r g b) + where + slope :: Maybe (Exp Double) -> Maybe (Exp Double) -> Exp Double -> Exp Double -> Exp Word8 + slope mlo mhi x m = + (P.maybe ($ 0) (\lo' -> max 0 . ($ lo')) mlo) $ \lo -> + (P.maybe ($ m) (\hi' -> min 255 . ($ hi')) mhi) $ \hi -> + fromIntegral @Int @Word8 $ round @Double @Int $ + max 0 $ min 255 $ (x - lo) * 255 / (hi - lo) + +mandel :: Exp Viewport -> Exp Int -> (Exp Int -> Exp RGB) -> Acc (Matrix RGB) +mandel (T3 (T2 w h) (T2 cx cy) cw) maxiter clrscheme = + generate (I2 h w) $ \(I2 yi xi) -> + let minx = cx - cw / 2 + ch = toFloating h / toFloating w * cw + maxy = cy + ch / 2 + x = minx + toFloating xi / (toFloating w - 1) * cw + y = maxy - toFloating yi / (toFloating h - 1) * ch + in clrscheme (mandeliter x y maxiter) + +mandeliter :: Exp Double -> Exp Double -> Exp Int -> Exp Int +mandeliter x y maxiter = + let T5 _ _ _ _ n = + while (\(T5 _ _ a2 b2 i) -> a2 + b2 < 4 && i < maxiter) + (\(T5 a b a2 b2 i) -> + let a' = a2 - b2 + x + b' = 2 * a * b + y + in T5 a' b' (a'*a') (b'*b') (i + 1)) + (T5 x y (x*x) (y*y) 0) + in n diff --git a/test/Examples/Mandel/Main.hs b/test/Examples/Mandel/Main.hs new file mode 100644 index 0000000..f1d49d1 --- /dev/null +++ b/test/Examples/Mandel/Main.hs @@ -0,0 +1,14 @@ +module Examples.Mandel.Main where + +import qualified Data.Array.Accelerate as A +import qualified Data.Array.Accelerate.Interpreter as I + +import qualified Examples.Mandel as Mandel +import Examples.Utils.PPM + + +main :: IO () +main = do + let viewport = ((640, 480), (-0.5, 0.0), 3.0) + img = I.run1 Mandel.afun (A.fromList A.Z [(viewport, 200)]) + ppmWrite img "mandel.ppm" diff --git a/test/Examples/Test.hs b/test/Examples/Test.hs new file mode 100644 index 0000000..f3df311 --- /dev/null +++ b/test/Examples/Test.hs @@ -0,0 +1,10 @@ +module Examples.Test (afun) where + +import Data.Array.Accelerate + + +afun :: Acc (Matrix Int, Vector (Int, Int)) + -> Acc (Matrix Int) +afun (T2 a b) = generate (I2 2 3) (\(I2 i j) -> + let T2 x y = b ! I1 i + in i * j + a ! I2 i j + x * y) diff --git a/test/Examples/Utils/PPM.hs b/test/Examples/Utils/PPM.hs new file mode 100644 index 0000000..fe8751d --- /dev/null +++ b/test/Examples/Utils/PPM.hs @@ -0,0 +1,18 @@ +module Examples.Utils.PPM where + +import qualified Data.Array.Accelerate as A +import Data.Word + + +type RGB = (Word8, Word8, Word8) + +ppmWrite :: A.Matrix RGB -> FilePath -> IO () +ppmWrite img fp = do + let A.Z A.:. h A.:. w = A.arrayShape img + line y = unwords $ concat [[show r, show g, show b] | x <- [0 .. w - 1], let (r, g, b) = A.indexArray img (A.Z A.:. y A.:. x)] + contents = unlines $ + ["P3" + ,show w ++ " " ++ show h + ,"255"] + ++ [line y | y <- [0 .. h - 1]] + writeFile fp contents diff --git a/test/Main.hs b/test/Main.hs new file mode 100644 index 0000000..19105dd --- /dev/null +++ b/test/Main.hs @@ -0,0 +1,20 @@ +module Main where + +import qualified Data.Array.Accelerate as A +import qualified Data.Array.Accelerate.Interpreter as I +import System.Exit + +import qualified Data.Array.Accelerate.C as C + +import qualified Examples.Mandel as Mandel +import Examples.Utils.PPM + + +main :: IO () +main = do + let viewport = ((640, 480), (-0.5, 0.0), 3.0) + img = I.run1 Mandel.afun (A.fromList A.Z [(viewport, 200)]) + case C.translateAcc "mandelkernel" Mandel.afun of + Left err -> die err + Right (code, _, _) -> writeFile "mandel-out.c" code + ppmWrite img "mandel.ppm" |