summaryrefslogtreecommitdiff
path: root/SC/Acc.hs
diff options
context:
space:
mode:
Diffstat (limited to 'SC/Acc.hs')
-rw-r--r--SC/Acc.hs124
1 files changed, 124 insertions, 0 deletions
diff --git a/SC/Acc.hs b/SC/Acc.hs
new file mode 100644
index 0000000..955c6da
--- /dev/null
+++ b/SC/Acc.hs
@@ -0,0 +1,124 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+module SC.Acc where
+
+import qualified Data.Array.Accelerate.AST as A
+import Data.Array.Accelerate.AST.LeftHandSide
+import Data.Array.Accelerate.AST.Var
+import Data.Array.Accelerate.Representation.Array
+import Data.Array.Accelerate.Representation.Type
+import Data.Array.Accelerate.Type
+import Data.Bifunctor
+import qualified Data.Set as Set
+
+import qualified Language.C as C
+import SC.Defs
+import SC.Monad
+
+
+data Command
+ = CChunk [C.FunDef] -- ^ Emitted top-level function definitions
+ [C.Stmt] -- ^ Code to execute
+ [C.Name] -- ^ Array variables used
+ | CAlloc [C.FunDef] -- ^ Emitted top-level function definitions
+ C.Type -- ^ Element type of the allocated array
+ C.Name -- ^ Variable to store it in
+ C.StExpr -- ^ Code that computes the array size
+ | CKeepalive C.Name -- ^ Never deallocate this
+ | CDealloc C.Name
+ deriving (Show)
+
+insertDeallocs :: [Command] -> [Command]
+insertDeallocs cmds =
+ let allocated = Set.fromList [n | CAlloc _ _ n _ <- cmds]
+ `Set.union` Set.fromList [n | CKeepalive n <- cmds]
+ in fst $ foldr
+ (\cmd (rest, done) -> case cmd of
+ CChunk _ _ used ->
+ let todealloc = filter (\n -> n `Set.member` allocated &&
+ n `Set.notMember` done)
+ used
+ in (cmd : map CDealloc todealloc ++ rest
+ ,done `Set.union` Set.fromList todealloc)
+ CAlloc _ _ name _
+ | name `Set.notMember` done -> (rest, done) -- unused alloc
+ | otherwise -> (cmd : rest, Set.delete name done)
+ CKeepalive _ -> (rest, done) -- already handled above in @allocated@
+ CDealloc _ -> error "insertDeallocs: CDealloc found")
+ ([], mempty) cmds
+
+compileCommands :: [Command] -> ([C.FunDef], [C.Stmt])
+compileCommands [] = ([], [])
+compileCommands (CChunk defs code _ : cmds) =
+ bimap (defs ++) (code ++) (compileCommands cmds)
+compileCommands (CAlloc defs typ name (C.StExpr szstmts szexpr) : cmds) =
+ let allocstmt = C.SDecl (C.TPtr typ) name
+ (Just (C.ECall (C.Name "malloc") [C.EOp szexpr "*" (C.ESizeOf typ)]))
+ in bimap (defs ++) ((szstmts ++ [allocstmt]) ++) (compileCommands cmds)
+compileCommands (CDealloc name : cmds) =
+ second ([C.SCall (C.Name "free") [C.EVar name]] ++) (compileCommands cmds)
+compileCommands (CKeepalive _ : cmds) = compileCommands cmds
+
+
+compileAcc' :: AVarEnv aenv -> TupANames t -> A.OpenAcc aenv t -> SC [Command]
+compileAcc' aenv dest (A.OpenAcc acc) = compilePAcc' aenv dest acc
+
+compilePAcc' :: AVarEnv aenv -> TupANames t -> A.PreOpenAcc A.OpenAcc aenv t -> SC [Command]
+compilePAcc' aenv destnames = \case
+ A.Alet lhs rhs body -> do
+ (names, aenv') <- genVarsAEnv lhs aenv
+ let sts1 = [C.SDecl t n Nothing | TypedAName t n <- itupList names]
+ let cmds1 = [CChunk [] sts1 []]
+ cmds2 <- compileAcc' aenv names rhs
+ cmds3 <- compileAcc' aenv' destnames body
+ return (cmds1 ++ cmds2 ++ cmds3)
+
+ A.Avar (Var _ idx) ->
+ return (Right ([], ITupSingle (C.EVar (aveprj aenv idx))))
+
+ A.Apair a b -> do
+ res1 <- compileAcc' aenv a
+ res2 <- compileAcc' aenv b
+ return (Left (\case
+ ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2
+ ITupIgnore -> []
+ ITupSingle _ -> error "wat"))
+
+ _ -> throw "Unsupported Acc constructor"
+ where
+ toStExprs :: TypeR t -> Either (ANames t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t)
+ toStExprs ty (Left fun) = do
+ names <- genAVars ty
+ let sts1 = fun names
+ return (sts1, itupmap (\(TypedName _ n) -> C.EVar n) names)
+ toStExprs _ (Right pair) = return pair
+
+ toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt]
+ toStoring (Left f) = f
+ toStoring (Right (sts, exs)) = (sts ++) . flip go exs
+ where
+ go :: Names t -> Exprs t -> [C.Stmt]
+ go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex]
+ go (ITupSingle _) _ = error "wat"
+ go ITupIgnore _ = []
+ go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2
+ go (ITupPair _ _) _ = error "wat"
+
+genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv')
+genVarsAEnv (LeftHandSideWildcard _) env = return (ANIgnore, env)
+genVarsAEnv (LeftHandSideSingle (ArrayR _ ty)) env = do
+ name <- genName "a"
+ ty' <- cvtType ty
+ return (ITupSingle (TypedAName ty' name), AVEPush _ name env)
+genVarsAEnv (LeftHandSidePair lhs1 lhs2) env = do
+ (n1, env1) <- genVarsAEnv lhs1 env
+ (n2, env2) <- genVarsAEnv lhs2 env1
+ return (ITupPair n1 n2, env2)
+
+genAVars :: TypeR t -> SC (ANames t)
+genAVars TupRunit = return ITupIgnore
+genAVars (TupRsingle ty) = genAVar ty
+genAVars (TupRpair t1 t2) = ITupPair <$> genAVars t1 <*> genAVars t2
+
+genAVar :: ScalarType t -> SC (ANames t)
+genAVar ty = ITupSingle <$> (TypedAName <$> cvtType ty <*> genName "a")