summaryrefslogtreecommitdiff
path: root/SC/Acc.hs
blob: 5ae2532a54313e37045c00fe92f202b04b97c92b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
module SC.Acc where

import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape hiding (zip)
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import Data.Bifunctor
import qualified Data.Set as Set

import qualified Language.C as C
import SC.Defs
import SC.Exp
import SC.Monad


-- | One step of the generated C program, as produced by the array compiler
-- and consumed by 'insertDeallocs' and 'compileCommands'.
data Command
    -- | A piece of straight-line code.
    = CChunk
          [C.FunDef]  -- ^ top-level function definitions this code needs
          [C.Stmt]    -- ^ statements to execute
          [C.Name]    -- ^ array variables referenced by the statements
    -- | Declare a fresh pointer variable and allocate an array into it.
    | CAlloc
          [C.FunDef]  -- ^ top-level function definitions this code needs
          C.Type      -- ^ element type of the allocated array
          C.Name      -- ^ variable receiving the pointer (newly declared!)
          C.StExpr    -- ^ code computing the number of elements
    -- | The named array must never be deallocated.
    | CKeepalive C.Name
    -- | Release the named array.
    | CDealloc C.Name
  deriving (Show)

-- | Insert a 'CDealloc' directly after the last 'CChunk' that uses each
-- allocated array, and drop 'CAlloc's whose array is never used.
--
-- Arrays named by a 'CKeepalive' are never deallocated, and their 'CAlloc'
-- is always kept, even if no chunk in the list mentions them. 'CKeepalive'
-- commands themselves are removed from the output.
--
-- Calls 'error' if the input already contains a 'CDealloc'.
insertDeallocs :: [Command] -> [Command]
insertDeallocs cmds =
    let keepalive = Set.fromList [n | CKeepalive n <- cmds]
        collectable = Set.fromList [n | CAlloc _ _ n _ <- cmds]
                          `Set.difference` keepalive
    in fst $ foldr
         -- Walk the commands back to front. @done@ holds the collectable
         -- names whose last use lies later in the list, i.e. whose dealloc
         -- has already been placed.
         (\cmd (rest, done) -> case cmd of
              CChunk _ _ used ->
                  -- Go through a Set so that an array listed several times
                  -- in @used@ is still freed only once.
                  let todealloc = Set.toList $
                          (Set.fromList used `Set.intersection` collectable)
                              `Set.difference` done
                  in (cmd : map CDealloc todealloc ++ rest
                     ,done `Set.union` Set.fromList todealloc)
              CAlloc _ _ name _
                  -- Kept-alive arrays never enter @done@ (they are not
                  -- collectable), so they must not be treated as unused.
                | name `Set.member` keepalive -> (cmd : rest, done)
                | name `Set.notMember` done -> (rest, done)  -- unused alloc
                | otherwise -> (cmd : rest, Set.delete name done)
              CKeepalive _ -> (rest, done)  -- already handled above in @keepalive@
              CDealloc _ -> error "insertDeallocs: CDealloc found")
         ([], mempty) cmds

-- | Lower a command list to C. Returns the top-level function definitions
-- and the statements to execute, both in command order.
--
-- A 'CAlloc' becomes its size-computing statements followed by a pointer
-- declaration initialised with @malloc@ (element count times @sizeof@ the
-- element type). A 'CDealloc' becomes a call to @free@. A 'CKeepalive'
-- produces nothing.
compileCommands :: [Command] -> ([C.FunDef], [C.Stmt])
compileCommands = foldr emit ([], [])
  where
    emit cmd ~(funs, stmts) = case cmd of
        CChunk defs code _ ->
            (defs ++ funs, code ++ stmts)
        CAlloc defs typ name (C.StExpr szstmts szexpr) ->
            let bytes = C.EOp szexpr "*" (C.ESizeOf typ)
                decl = C.SDecl (C.TPtr typ) name
                               (Just (C.ECall (C.Name "malloc") [bytes]))
            in (defs ++ funs, szstmts ++ decl : stmts)
        CDealloc name ->
            (funs, C.SCall (C.Name "free") [C.EVar name] : stmts)
        CKeepalive _ ->
            (funs, stmts)


-- | Compile an 'A.OpenAcc' term: unwrap it and delegate to 'compilePAcc''.
compileAcc' :: AVarEnv aenv -> TupANames t -> A.OpenAcc aenv t -> SC [Command]
compileAcc' aenv dest term = case term of
    A.OpenAcc pacc -> compilePAcc' aenv dest pacc

-- | Compile one level of an Accelerate array term into a list of 'Command's
-- that store the term's result in the variables named by @destnames@.
--
-- Supported constructors are 'A.Alet', 'A.Avar', 'A.Apair' and 'A.Generate'.
-- The 'A.Avar', 'A.Apair' and 'A.Generate' cases also require @destnames@ to
-- have the matching form ('ANArray' or 'ANPair'). If that pattern guard
-- fails, evaluation falls through to the last alternative and fails with
-- @throw "Unsupported Acc constructor"@.
compilePAcc' :: AVarEnv aenv -> TupANames t -> A.PreOpenAcc A.OpenAcc aenv t -> SC [Command]
compilePAcc' aenv destnames = \case
    -- let lhs = rhs in body:
    --   * Declare fresh C variables, without initialiser, for every shape
    --     and array variable bound by @lhs@.
    --   * Compile @rhs@ into those variables.
    --   * Compile @body@, in the extended environment, into the outer
    --     destination.
    -- NOTE(review): the array variables are declared here with 'C.SDecl'.
    -- When @rhs@ is an 'A.Generate', its 'CAlloc' commands declare the same
    -- names again ("newly declared!", emitted as an 'C.SDecl' by
    -- 'compileCommands'). This looks like a duplicate declaration in the
    -- generated C; confirm.
    A.Alet lhs rhs body -> do
        (names, aenv') <- genVarsAEnv lhs aenv
        let sts1sh = [C.SDecl t n Nothing | TypedName t n <- fst (tupanamesList names)]
            sts1arr = [C.SDecl t n Nothing | TypedAName t n <- snd (tupanamesList names)]
        let cmds1 = [CChunk [] (sts1sh ++ sts1arr) []]
        cmds2 <- compileAcc' aenv names rhs
        cmds3 <- compileAcc' aenv' destnames body
        return (cmds1 ++ cmds2 ++ cmds3)

    -- Array variable: assign the referenced binding's shape variables and
    -- array pointers to the destination variables. Only the source array
    -- names are reported as used.
    -- NOTE(review): the destination arrays become aliases of the source
    -- pointers but are not reported as used. 'insertDeallocs' may therefore
    -- free the source after this chunk while the alias is still in use.
    -- Verify.
    A.Avar (Var _ idx)
      | ANArray destshnames destarrnames <- destnames -> do
          let (shnames, arrnames) = aveprj aenv idx
              sts = [C.SAsg destn (C.EVar srcn)
                    | (TypedName _ destn, TypedName _ srcn) <- zip (shnamesList destshnames) (shnamesList shnames)]
                    ++
                    [C.SAsg destn (C.EVar srcn)
                    | (destn, srcn) <- zipDestSrcNamesAA destarrnames arrnames]
              usedA = map (\(TypedAName _ n) -> n) (itupList arrnames)
          return [CChunk [] sts usedA]

    -- Pair of arrays: compile each component into its own destination,
    -- left component first.
    A.Apair a b
      | ANPair destnames1 destnames2 <- destnames -> do
          res1 <- compileAcc' aenv destnames1 a
          res2 <- compileAcc' aenv destnames2 b
          return (res1 ++ res2)

    -- generate sh f produces three groups of commands:
    --   1. A call to the compiled shape function; its output arguments are
    --      presumably the destination shape variables.
    --   2. One 'CAlloc' per destination array, of 'computeSize' elements.
    --   3. A nested loop over all indices ('enumShapeNested'). Each iteration
    --      calls the compiled element function into fresh temporaries and
    --      stores them at the linear index.
    A.Generate _ she fun@(A.Lam _ (A.Body (A.expType -> restype)))
      | ANArray destshnames destarrnames <- destnames -> do
          CompiledFun sheFD sheArgbuilder usedAshe <- compileExp aenv she
          CompiledFun funFD funArgbuilder usedAfun <- compileFun aenv fun
          tempnames <- genVars restype
          loops <- enumShapeNested destshnames $ \idxnames linidxexpr -> concat
                       [[C.SDecl t n Nothing | TypedName t n <- itupList tempnames]
                       ,[C.SCall (C.fundefName funFD)
                                 (funArgbuilder (itupEvars (fromShNames idxnames)) tempnames)]
                       ,[C.SStore arrname linidxexpr (C.EVar tempname)
                        | (arrname, tempname) <- zipDestSrcNamesAE destarrnames tempnames]]
          -- Array variables carry pointer types (see 'genAVar'), so the
          -- element type for the allocation is taken from under 'C.TPtr'.
          return . concat $
              [[CChunk [sheFD]
                       [C.SCall (C.fundefName sheFD)
                                (sheArgbuilder ITupIgnore (fromShNames destshnames))]
                       (concatMap (\(SomeArray _ ans) ->
                                       map (\(TypedAName _ n) -> n) (itupList ans))
                                  usedAshe)]
              ,[CAlloc [] eltty n (C.StExpr [] (computeSize destshnames))
               | TypedAName arrty n <- itupList destarrnames
               , let C.TPtr eltty = arrty]
              ,[CChunk [funFD]
                       loops
                       (map (\(TypedAName _ n) -> n)
                            (itupList destarrnames
                               ++ concatMap (\(SomeArray _ ans) -> itupList ans) usedAfun))]]

    -- Any other constructor, or a destination whose form does not match.
    _ -> throw "Unsupported Acc constructor"

-- | Total element count of an array with the given size variables: the
-- product of all of them. Returns an expression of type int64_t. A rank-0
-- shape has size @1LL@, and a rank-1 shape is just its single size variable.
computeSize :: ShNames sh -> C.Expr
computeSize = \case
    ShZ -> C.ELit "1LL"
    ShS inner n -> case inner of
        ShZ -> C.EVar n
        _ -> C.EOp (computeSize inner) "*" (C.EVar n)

-- | Given size variables and index variables, returns an expression of type
-- int64_t: the row-major linear index of the element.
--
-- The last dimension varies fastest: @((i1 * n2) + i2) * n3 + i3@, and so on.
-- A rank-0 array holds exactly one element ('computeSize' gives @1LL@), so
-- its only valid linear index is 0.
linearIndexExpr :: ShNames sh -> ShNames sh -> C.Expr
linearIndexExpr ShZ ShZ = C.ELit "0LL"
-- Rank 1 is special-cased only to avoid emitting @(0LL * n) + i@.
linearIndexExpr (ShS ShZ _) (ShS ShZ i) = C.EVar i
linearIndexExpr (ShS ns n) (ShS is i) =
    C.EOp (C.EOp (linearIndexExpr ns is) "*" (C.EVar n)) "+" (C.EVar i)

-- | Pair destination and source variable names leaf by leaf, in
-- left-to-right order.
--
-- Components the destination marks as 'ITupIgnore' are skipped, whatever the
-- source holds there. Calls 'error' if the source ignores a component the
-- destination expects, or if the two tuples have a different structure.
zipDestSrcNames :: ITup C.Name t -> ITup C.Name t -> [(C.Name, C.Name)]
zipDestSrcNames ITupIgnore _ = []
zipDestSrcNames _ ITupIgnore =
    error "zipDestSrcNames: Ignore in source names but not in destination names"
zipDestSrcNames (ITupSingle n) (ITupSingle n') = [(n, n')]
zipDestSrcNames (ITupPair a b) (ITupPair a' b') = zipDestSrcNames a a' ++ zipDestSrcNames b b'
zipDestSrcNames _ _ =
    error "zipDestSrcNames: destination and source names have different tuple structure"

-- | 'zipDestSrcNames' for two tuples of typed array names; the C types are
-- discarded and only the variable names are paired.
zipDestSrcNamesAA :: ANames t -> ANames t -> [(C.Name, C.Name)]
zipDestSrcNamesAA dests srcs = zipDestSrcNames (bare dests) (bare srcs)
  where
    bare = itupmap (\(TypedAName _ n) -> n)

-- | 'zipDestSrcNames' pairing typed array names (destinations) with typed
-- scalar names (sources); the C types are discarded.
zipDestSrcNamesAE :: ANames t -> Names t -> [(C.Name, C.Name)]
zipDestSrcNamesAE arrs scalars =
    let arrNames = itupmap (\(TypedAName _ n) -> n) arrs
        scalarNames = itupmap (\(TypedName _ n) -> n) scalars
    in zipDestSrcNames arrNames scalarNames

-- | Build a loop nest that visits every index of a shape.
--
-- Takes the shape's size variables and a body builder. The builder receives
-- freshly generated index variables and an expression computing the linear
-- index from the size and index variables, and returns the loop body.
--
-- There is one @int64_t@ loop per dimension, each running from 0 below its
-- size variable. The last dimension is the innermost loop.
enumShapeNested :: ShNames sh -> (ShNames sh -> C.Expr -> [C.Stmt]) -> SC [C.Stmt]
enumShapeNested sizenames fun = do
    idxnames <- genShNames (shNamesShape sizenames)
    let body = fun idxnames (linearIndexExpr sizenames idxnames)
    return (wrapLoops sizenames idxnames body)
  where
    -- Wrap the innermost dimension's loop first, then recurse outwards.
    wrapLoops :: ShNames s -> ShNames s -> [C.Stmt] -> [C.Stmt]
    wrapLoops ShZ ShZ inner = inner
    wrapLoops (ShS sizes size) (ShS idxs idx) inner =
        let loop = C.SFor (C.TInt C.B64) idx (C.ELit "0") (C.EVar size) inner
        in wrapLoops sizes idxs [loop]

-- | Generate fresh C variables for everything an array left-hand side binds,
-- and push them onto the array environment.
--
-- A wildcard binds nothing and leaves the environment unchanged. For a
-- single array, the shape names are generated before the array names. For
-- a pair, the left side is processed first and the environment is threaded
-- through.
genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv')
genVarsAEnv lhs env = case lhs of
    LeftHandSideWildcard _ ->
        return (ANIgnore, env)
    LeftHandSideSingle (ArrayR shapeR eltR) -> do
        shNames <- genShNames shapeR
        arrNames <- genAVars eltR
        return (ANArray shNames arrNames, AVEPush shNames arrNames env)
    LeftHandSidePair left right -> do
        (leftNames, envL) <- genVarsAEnv left env
        (rightNames, envR) <- genVarsAEnv right envL
        return (ANPair leftNames rightNames, envR)

-- | Generate fresh shape and array variables for a tuple of arrays.
-- Unit yields 'ANIgnore'. Components are generated left to right, and for
-- each array the shape names come before the array names.
genAVarsTup :: ArraysR t -> SC (TupANames t)
genAVarsTup = \case
    TupRunit -> return ANIgnore
    TupRsingle (ArrayR shapeR eltR) -> do
        shNames <- genShNames shapeR
        arrNames <- genAVars eltR
        return (ANArray shNames arrNames)
    TupRpair left right -> do
        leftNames <- genAVarsTup left
        rightNames <- genAVarsTup right
        return (ANPair leftNames rightNames)

-- | Generate one fresh array variable per scalar component of an element
-- type, left to right. Unit components get no variable ('ITupIgnore').
genAVars :: TypeR t -> SC (ANames t)
genAVars = \case
    TupRunit -> return ITupIgnore
    TupRsingle scalarTy -> genAVar scalarTy
    TupRpair left right -> do
        leftNames <- genAVars left
        rightNames <- genAVars right
        return (ITupPair leftNames rightNames)

-- | Generate one fresh size variable (name hint @"n"@) per dimension of a
-- shape. The names for the outer dimensions are generated before the name
-- for the last dimension.
genShNames :: ShapeR sh -> SC (ShNames sh)
genShNames = \case
    ShapeRz -> return ShZ
    ShapeRsnoc inner -> ShS <$> genShNames inner <*> genName "n"

-- | Generate a fresh array variable (name hint @"a"@) for one scalar
-- component. Its C type is a pointer to the converted scalar type.
genAVar :: ScalarType t -> SC (ANames t)
genAVar ty = do
    eltTy <- cvtType ty
    arrName <- genName "a"
    return (ITupSingle (TypedAName (C.TPtr eltTy) arrName))