summaryrefslogtreecommitdiff
path: root/SC/Acc.hs
blob: 955c6da3c3217b504f2b4a876320bf1e4f7cd792 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
module SC.Acc where

import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import Data.Bifunctor
import qualified Data.Set as Set

import qualified Language.C as C
import SC.Defs
import SC.Monad


-- | A single step of a compiled array program.
--
-- Lists of these are produced by the compiler, have 'CDealloc's added by
-- 'insertDeallocs', and are finally lowered to C by 'compileCommands'.
data Command
  = -- | A chunk of straight-line code.
    CChunk
      [C.FunDef]  -- ^ Top-level function definitions this chunk needs
      [C.Stmt]    -- ^ Statements to run
      [C.Name]    -- ^ Array variables the statements refer to
  | -- | Allocation of a new array variable.
    CAlloc
      [C.FunDef]  -- ^ Top-level function definitions this allocation needs
      C.Type      -- ^ Element type of the array
      C.Name      -- ^ Name of the variable that receives the array
      C.StExpr    -- ^ Statements plus expression computing the element count
  | -- | Marks an array variable that must never be freed.
    CKeepalive C.Name
  | -- | Frees an array variable.
    CDealloc C.Name
  deriving Show

-- | Insert 'CDealloc' commands after the last use of every allocated array.
--
-- The command list is walked from the back to the front (via 'foldr'). During
-- the walk we track the set of names whose deallocation has already been
-- scheduled further down the list.
--
-- * For a 'CChunk', every used name that (a) was allocated by a 'CAlloc',
--   (b) is not marked 'CKeepalive', and (c) has no deallocation scheduled yet
--   is at its last use. A 'CDealloc' for it is inserted directly after the
--   chunk. These frees are emitted in ascending name order, and each name at
--   most once.
-- * A 'CAlloc' whose array is never used afterwards is dropped as dead code,
--   unless its name is marked 'CKeepalive'.
-- * Reaching a 'CAlloc' removes its name from the scheduled set. An earlier
--   allocation that reuses the same name is then tracked on its own.
-- * 'CKeepalive' markers are consumed and do not appear in the output; the
--   names they mark are never freed.
--
-- The input must not already contain 'CDealloc' commands.
insertDeallocs :: [Command] -> [Command]
insertDeallocs cmds = fst (foldr step ([], Set.empty) cmds)
  where
    -- Names that must never be freed.
    keepalive :: Set.Set C.Name
    keepalive = Set.fromList [n | CKeepalive n <- cmds]

    -- Names this pass is responsible for freeing: everything allocated here,
    -- except the keep-alive names. (Previously the keep-alive names were
    -- unioned in, which caused them to be freed after their last use.)
    managed :: Set.Set C.Name
    managed = Set.fromList [n | CAlloc _ _ n _ <- cmds] `Set.difference` keepalive

    step :: Command -> ([Command], Set.Set C.Name) -> ([Command], Set.Set C.Name)
    step cmd (rest, done) = case cmd of
        CChunk _ _ used ->
            -- Going through a Set also removes duplicate entries in @used@,
            -- so a name is never freed twice after the same chunk.
            let live = Set.fromList used `Set.intersection` managed
                todealloc = Set.toList (live `Set.difference` done)
            in (cmd : map CDealloc todealloc ++ rest
               ,done `Set.union` Set.fromList todealloc)
        CAlloc _ _ name _
          | name `Set.member` keepalive -> (cmd : rest, done)  -- must stay alive
          | name `Set.notMember` done -> (rest, done)          -- unused alloc
          | otherwise -> (cmd : rest, Set.delete name done)
        CKeepalive _ -> (rest, done)  -- already accounted for in @keepalive@
        CDealloc _ -> error "insertDeallocs: CDealloc found"

-- | Lower a command list to C.
--
-- Returns the collected top-level function definitions and the statement
-- sequence, both in command order.
--
-- * An allocation becomes its size statements followed by a
--   @T *name = malloc(size * sizeof(T))@ declaration.
-- * A deallocation becomes a call to @free@.
-- * Keep-alive markers produce no code.
compileCommands :: [Command] -> ([C.FunDef], [C.Stmt])
compileCommands = foldr emit ([], [])
  where
    -- The lazy pattern keeps the output as lazy as plain recursion would be.
    emit cmd ~(accDefs, accStmts) = case cmd of
        CChunk chunkDefs body _ ->
            (chunkDefs ++ accDefs, body ++ accStmts)
        CAlloc allocDefs elemTy var (C.StExpr sizeStmts sizeExpr) ->
            let byteCount = C.EOp sizeExpr "*" (C.ESizeOf elemTy)
                mallocCall = C.ECall (C.Name "malloc") [byteCount]
                decl = C.SDecl (C.TPtr elemTy) var (Just mallocCall)
            in (allocDefs ++ accDefs, sizeStmts ++ decl : accStmts)
        CDealloc var ->
            (accDefs, C.SCall (C.Name "free") [C.EVar var] : accStmts)
        CKeepalive _ ->
            (accDefs, accStmts)


-- | Compile an array term into commands that target the given destination
-- names. Only unwraps the 'A.OpenAcc' and delegates to 'compilePAcc''.
compileAcc' :: AVarEnv aenv -> TupANames t -> A.OpenAcc aenv t -> SC [Command]
compileAcc' env targets term = case term of
    A.OpenAcc inner -> compilePAcc' env targets inner

-- | Compile an array-level term ('A.PreOpenAcc') into a list of 'Command's.
--
-- @aenv@ gives the C variable names for the term's array environment.
-- @destnames@ are the typed C names the result is meant to end up in.
-- Only 'A.Alet', 'A.Avar' and 'A.Apair' are handled. Every other constructor
-- fails via @throw@ with "Unsupported Acc constructor".
--
-- NOTE(review): this function appears to be mid-refactor and does not type
-- check as written. See the notes on the 'A.Avar' and 'A.Apair' alternatives
-- and on 'toStExprs'.
compilePAcc' :: AVarEnv aenv -> TupANames t -> A.PreOpenAcc A.OpenAcc aenv t -> SC [Command]
compilePAcc' aenv destnames = \case
    -- Let-binding:
    --   1. generate fresh C names for the bound arrays;
    --   2. declare them (no initialiser) in a chunk of their own;
    --   3. compile the right-hand side into those names;
    --   4. compile the body, in the extended environment, into the caller's
    --      destination names.
    A.Alet lhs rhs body -> do
        (names, aenv') <- genVarsAEnv lhs aenv
        let sts1 = [C.SDecl t n Nothing | TypedAName t n <- itupList names]
        let cmds1 = [CChunk [] sts1 []]
        cmds2 <- compileAcc' aenv names rhs
        cmds3 <- compileAcc' aenv' destnames body
        return (cmds1 ++ cmds2 ++ cmds3)

    -- Array variable: look up the C name bound to this environment index.
    -- NOTE(review): this returns @Right (stmts, exprs)@, which is not an
    -- @SC [Command]@. The shape looks copied from a scalar-expression
    -- compiler. It presumably should emit a command that copies or aliases
    -- the array into @destnames@. TODO confirm the intended semantics, and
    -- how aliasing interacts with 'insertDeallocs'.
    A.Avar (Var _ idx) ->
        return (Right ([], ITupSingle (C.EVar (aveprj aenv idx))))

    -- Pair of arrays: compile each component separately.
    -- NOTE(review): two problems here.
    --   * 'compileAcc'' takes three arguments (environment, destination
    --     names, term) but is called with two.
    --   * The result is a @Left@ storing function rather than @[Command]@.
    -- The likely fix is to split @destnames@ and compile @a@ and @b@ into the
    -- two halves. TODO confirm.
    A.Apair a b -> do
        res1 <- compileAcc' aenv a
        res2 <- compileAcc' aenv b
        return (Left (\case
            ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2
            ITupIgnore -> []
            ITupSingle _ -> error "wat"))

    _ -> throw "Unsupported Acc constructor"
  where
    -- Normalise a result into (statements, expressions). The result is
    -- either a "store into these names" function (@Left@) or precomputed
    -- statements plus expressions (@Right@).
    -- For @Left@: fresh names shaped by @ty@ are generated, the function is
    -- applied to them, and the names are returned as variable expressions.
    -- NOTE(review): not referenced by any alternative above. Also, 'genAVars'
    -- yields 'TypedAName' leaves, while the lambda below matches 'TypedName'.
    -- Check which name type is intended.
    toStExprs :: TypeR t -> Either (ANames t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t)
    toStExprs ty (Left fun) = do
        names <- genAVars ty
        let sts1 = fun names
        return (sts1, itupmap (\(TypedName _ n) -> C.EVar n) names)
    toStExprs _ (Right pair) = return pair

    -- Produce statements that store a result into the given names.
    --   * A @Left@ storing function is applied directly.
    --   * For @Right (sts, exs)@, @sts@ runs first, followed by one assignment
    --     per leaf, pairing each name with its expression.
    toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt]
    toStoring (Left f) = f
    toStoring (Right (sts, exs)) = (sts ++) . flip go exs
      where
        -- Walk names and expressions in lockstep. Ignored name leaves emit
        -- nothing. A shape mismatch between the two trees is an internal
        -- error ("wat").
        go :: Names t -> Exprs t -> [C.Stmt]
        go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex]
        go (ITupSingle _) _ = error "wat"
        go ITupIgnore _ = []
        go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2
        go (ITupPair _ _) _ = error "wat"

-- | Bind the arrays introduced by an array left-hand side.
--
-- Generates one fresh C name (prefix "a") for every array the LHS binds.
-- Returns the typed names, in the same tuple shape as the LHS, together with
-- the environment extended by those names. For a pair LHS, the left
-- component is processed first and its extended environment is threaded into
-- the right component.
genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv')
-- A wildcard binds nothing, so the environment is returned unchanged.
-- NOTE(review): the rest of this file uses 'ITupIgnore' for the "ignored"
-- leaf. Confirm that 'ANIgnore' exists and means the same thing.
genVarsAEnv (LeftHandSideWildcard _) env = return (ANIgnore, env)
-- NOTE(review): @ty@ is the second field of 'ArrayR' (in Accelerate, the
-- element 'TypeR'), whereas 'genAVar' passes a 'ScalarType' to 'cvtType'.
-- Confirm that 'cvtType' accepts this argument.
-- NOTE(review): the @_@ passed to 'AVEPush' is a typed hole; GHC rejects the
-- module until it is replaced with the real argument. TODO fill in.
genVarsAEnv (LeftHandSideSingle (ArrayR _ ty)) env = do
    name <- genName "a"
    ty' <- cvtType ty
    return (ITupSingle (TypedAName ty' name), AVEPush _ name env)
genVarsAEnv (LeftHandSidePair lhs1 lhs2) env = do
    (n1, env1) <- genVarsAEnv lhs1 env
    (n2, env2) <- genVarsAEnv lhs2 env1
    return (ITupPair n1 n2, env2)

-- | Generate fresh typed C names for every scalar leaf of a tuple type.
--
-- The unit type yields an ignored leaf. Pairs are handled left component
-- first, then right.
genAVars :: TypeR t -> SC (ANames t)
genAVars = \case
    TupRunit -> pure ITupIgnore
    TupRsingle scalarTy -> genAVar scalarTy
    TupRpair leftTy rightTy -> do
        leftNames <- genAVars leftTy
        rightNames <- genAVars rightTy
        return (ITupPair leftNames rightNames)

-- | Generate a single fresh typed C name (prefix "a") for a scalar type.
-- The type is converted first, then the name is generated.
genAVar :: ScalarType t -> SC (ANames t)
genAVar scalarTy = do
    cTy <- cvtType scalarTy
    freshName <- genName "a"
    return (ITupSingle (TypedAName cTy freshName))