1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
|
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
module SC.Acc where
import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape hiding (zip)
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import Data.Bifunctor
import qualified Data.Set as Set
import qualified Language.C as C
import SC.Defs
import SC.Monad
-- | One step of the generated C program.
--
-- Commands are produced by the compiler ('compilePAcc'') and lowered to C by
-- 'compileCommands'; 'insertDeallocs' places 'CDealloc' steps in between.
data Command
  -- | Straight-line code together with the array variables it touches.
  = CChunk
      [C.FunDef] -- ^ Emitted top-level function definitions
      [C.Stmt]   -- ^ Code to execute
      [C.Name]   -- ^ Array variables used
  -- | Heap allocation of a fresh array.
  | CAlloc
      [C.FunDef] -- ^ Emitted top-level function definitions
      C.Type     -- ^ Element type of the allocated array
      C.Name     -- ^ Variable to store it in
      C.StExpr   -- ^ Code that computes the array size
  -- | Marker: the named array must never be deallocated.
  | CKeepalive C.Name
  -- | Release the array held in the named variable.
  | CDealloc C.Name
  deriving (Show)
-- | Insert a 'CDealloc' right after the last use of every collectable array,
-- and drop allocations of collectable arrays that are never used.
--
-- An array is /collectable/ when it is allocated by some 'CAlloc' and not
-- named by any 'CKeepalive'. The list is walked right-to-left with 'foldr',
-- so the first 'CChunk' seen (on that walk) that lists a collectable name is
-- its last use in program order; the 'CDealloc' goes immediately after it.
--
-- Kept-alive arrays are never freed, and their 'CAlloc' is always retained.
-- 'CKeepalive' markers themselves are removed from the output. The input
-- must not already contain 'CDealloc' commands; encountering one is an
-- internal error.
insertDeallocs :: [Command] -> [Command]
insertDeallocs cmds = fst (foldr step ([], mempty) cmds)
  where
    -- Arrays this pass is allowed to free.
    collectable =
      Set.fromList [n | CAlloc _ _ n _ <- cmds]
        `Set.difference` Set.fromList [n | CKeepalive n <- cmds]

    -- @done@ holds the collectable names that already have a 'CDealloc'
    -- placed somewhere to the right of the current command.
    step cmd (rest, done) = case cmd of
      CChunk _ _ used ->
        -- Deduplicate so a name listed twice in one chunk is freed only once.
        let todealloc =
              dedup
                [ n
                | n <- used
                , n `Set.member` collectable
                , n `Set.notMember` done
                ]
         in ( cmd : map CDealloc todealloc ++ rest
            , done `Set.union` Set.fromList todealloc
            )
      CAlloc _ _ name _
        -- Kept-alive arrays never enter @done@, so they must not be mistaken
        -- for unused allocations below.
        | name `Set.notMember` collectable -> (cmd : rest, done)
        | name `Set.notMember` done -> (rest, done) -- unused alloc
        | otherwise -> (cmd : rest, Set.delete name done)
      CKeepalive _ -> (rest, done) -- already handled in @collectable@
      CDealloc _ -> error "insertDeallocs: CDealloc found"

    -- Remove duplicates while keeping first-occurrence order.
    dedup names = go Set.empty names
      where
        go _ [] = []
        go seen (n : ns)
          | n `Set.member` seen = go seen ns
          | otherwise = n : go (Set.insert n seen) ns
-- | Lower a command list to C.
--
-- Returns the concatenated top-level function definitions of all commands
-- and the statement sequence to execute, both in command order:
--
-- * 'CChunk' contributes its definitions and statements verbatim.
-- * 'CAlloc' contributes its size-computation statements, followed by a
--   pointer declaration initialised with @malloc(size * sizeof(elem))@.
-- * 'CDealloc' contributes a call to @free@.
-- * 'CKeepalive' contributes nothing.
compileCommands :: [Command] -> ([C.FunDef], [C.Stmt])
compileCommands = foldr lower ([], [])
  where
    lower (CChunk defs code _) = bimap (defs ++) (code ++)
    lower (CAlloc defs typ name (C.StExpr szstmts szexpr)) =
      let bytes = C.EOp szexpr "*" (C.ESizeOf typ)
          decl = C.SDecl (C.TPtr typ) name (Just (C.ECall (C.Name "malloc") [bytes]))
       in bimap (defs ++) ((szstmts ++ [decl]) ++)
    lower (CDealloc name) = second (C.SCall (C.Name "free") [C.EVar name] :)
    lower (CKeepalive _) = id
-- | Compile an array computation so that its result ends up in the given
-- destination names. Unwraps the 'A.OpenAcc' node and delegates to
-- 'compilePAcc''.
compileAcc' :: AVarEnv aenv -> TupANames t -> A.OpenAcc aenv t -> SC [Command]
compileAcc' aenv dest openacc = case openacc of
  A.OpenAcc pacc -> compilePAcc' aenv dest pacc
-- | Compile one array-level AST node into 'Command's that store its result
-- into @destnames@.
--
-- Supported constructors are 'A.Alet', 'A.Avar' and 'A.Apair'. Anything else
-- fails in the 'SC' monad with "Unsupported Acc constructor".
--
-- An 'ANIgnore' destination (produced by 'genVarsAEnv' for wildcard
-- let-binders) means the value is discarded:
--
-- * 'A.Avar' emits nothing.
-- * 'A.Apair' compiles both components with ignored destinations, so
--   unsupported sub-terms are still reported.
compilePAcc' :: AVarEnv aenv -> TupANames t -> A.PreOpenAcc A.OpenAcc aenv t -> SC [Command]
compilePAcc' aenv destnames = \case
  A.Alet lhs rhs body -> do
    (names, aenv') <- genVarsAEnv lhs aenv
    -- Declare the let-bound shape and array variables up front, uninitialised.
    -- NOTE(review): array variables are declared with the type stored in their
    -- TypedAName (genAVar's cvtType of the element ScalarType), whereas CAlloc
    -- declares arrays as @C.TPtr elemtype@. Confirm these should not be
    -- pointers too.
    let sts1sh = [C.SDecl t n Nothing | TypedName t n <- fst (tupanamesList names)]
        sts1arr = [C.SDecl t n Nothing | TypedAName t n <- snd (tupanamesList names)]
        cmds1 = [CChunk [] (sts1sh ++ sts1arr) []]
    -- The right-hand side sees only the outer environment; the body sees the
    -- environment extended with the new bindings.
    cmds2 <- compileAcc' aenv names rhs
    cmds3 <- compileAcc' aenv' destnames body
    return (cmds1 ++ cmds2 ++ cmds3)
  A.Avar (Var _ idx)
    | ANIgnore <- destnames -> return []
    | ANArray destshnames destarrnames <- destnames -> do
        -- Copy shape and array variables by plain assignment.
        -- NOTE(review): if array variables are pointers, the destination
        -- aliases the source buffer, yet only the source names are reported
        -- as used. Confirm insertDeallocs cannot free the source while the
        -- alias is still live.
        let (shnames, arrnames) = aveprj aenv idx
            shSts =
              [ C.SAsg destn (C.EVar srcn)
              | (TypedName _ destn, TypedName _ srcn) <-
                  zip (shnamesList destshnames) (shnamesList shnames)
              ]
            arrSts =
              [ C.SAsg destn (C.EVar srcn)
              | (destn, srcn) <- zipDestSrcNames destarrnames arrnames
              ]
            usedA = map (\(TypedAName _ n) -> n) (itupList arrnames)
        return [CChunk [] (shSts ++ arrSts) usedA]
  A.Apair a b
    | ANIgnore <- destnames -> do
        res1 <- compileAcc' aenv ANIgnore a
        res2 <- compileAcc' aenv ANIgnore b
        return (res1 ++ res2)
    | ANPair destnames1 destnames2 <- destnames -> do
        res1 <- compileAcc' aenv destnames1 a
        res2 <- compileAcc' aenv destnames2 b
        return (res1 ++ res2)
  _ -> throw "Unsupported Acc constructor"
-- | Pair up destination and source array-variable names, leaf by leaf.
--
-- Returns @(destination, source)@ pairs in left-to-right leaf order.
--
-- * An ignored destination ('ITupIgnore') contributes no pairs, whatever the
--   source is.
-- * An ignored source under a non-ignored destination is an internal error.
-- * Two trees of different shape are an internal error.
zipDestSrcNames :: ANames e -> ANames e -> [(C.Name, C.Name)]
zipDestSrcNames ITupIgnore _ = []
zipDestSrcNames _ ITupIgnore = error "Ignore in source names where there is none in the destination names"
zipDestSrcNames (ITupSingle (TypedAName _ n)) (ITupSingle (TypedAName _ n')) = [(n, n')]
zipDestSrcNames (ITupPair a b) (ITupPair a' b') = zipDestSrcNames a a' ++ zipDestSrcNames b b'
zipDestSrcNames _ _ = error "zipDestSrcNames: destination and source name trees have different shapes"
-- | Generate fresh C names for every array bound by a left-hand side, and
-- extend the array environment accordingly.
--
-- * A wildcard binds nothing: it yields 'ANIgnore' and leaves the environment
--   unchanged.
-- * A single array gets fresh shape names, then fresh element-array names.
-- * Pairs are processed left then right, threading the environment through.
genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv')
genVarsAEnv lhs env = case lhs of
  LeftHandSideWildcard _ -> return (ANIgnore, env)
  LeftHandSideSingle (ArrayR shr eltr) -> do
    sh <- genShNames shr
    arrs <- genAVars eltr
    return (ANArray sh arrs, AVEPush sh arrs env)
  LeftHandSidePair left right -> do
    (leftNames, envL) <- genVarsAEnv left env
    (rightNames, envR) <- genVarsAEnv right envL
    return (ANPair leftNames rightNames, envR)
-- | Generate fresh array-variable names for each scalar leaf of an element
-- type.
--
-- The unit type needs no variable ('ITupIgnore'). Pairs generate names for
-- the left component before the right.
genAVars :: TypeR t -> SC (ANames t)
genAVars = \case
  TupRunit -> return ITupIgnore
  TupRsingle sty -> genAVar sty
  TupRpair l r -> do
    lnames <- genAVars l
    rnames <- genAVars r
    return (ITupPair lnames rnames)
-- | Generate one fresh name (prefix @"n"@) per dimension of a shape.
--
-- The name for the outermost 'ShapeRsnoc' is generated before the names of
-- the inner dimensions.
genShNames :: ShapeR sh -> SC (ShNames sh)
genShNames ShapeRz = return ShZ
genShNames (ShapeRsnoc inner) = ShS <$> genName "n" <*> genShNames inner
-- | Generate a single fresh array-variable name (prefix @"a"@) for a scalar
-- leaf, tagged with the C type produced by 'cvtType'. The type conversion
-- runs before the name is generated.
genAVar :: ScalarType t -> SC (ANames t)
genAVar ty = do
  cty <- cvtType ty
  name <- genName "a"
  return (ITupSingle (TypedAName cty name))
|