diff options
Diffstat (limited to 'SC/Acc.hs')
-rw-r--r-- | SC/Acc.hs | 124 |
1 files changed, 124 insertions, 0 deletions
diff --git a/SC/Acc.hs b/SC/Acc.hs new file mode 100644 index 0000000..955c6da --- /dev/null +++ b/SC/Acc.hs @@ -0,0 +1,124 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +module SC.Acc where + +import qualified Data.Array.Accelerate.AST as A +import Data.Array.Accelerate.AST.LeftHandSide +import Data.Array.Accelerate.AST.Var +import Data.Array.Accelerate.Representation.Array +import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Type +import Data.Bifunctor +import qualified Data.Set as Set + +import qualified Language.C as C +import SC.Defs +import SC.Monad + + +data Command + = CChunk [C.FunDef] -- ^ Emitted top-level function definitions + [C.Stmt] -- ^ Code to execute + [C.Name] -- ^ Array variables used + | CAlloc [C.FunDef] -- ^ Emitted top-level function definitions + C.Type -- ^ Element type of the allocated array + C.Name -- ^ Variable to store it in + C.StExpr -- ^ Code that computes the array size + | CKeepalive C.Name -- ^ Never deallocate this + | CDealloc C.Name + deriving (Show) + +insertDeallocs :: [Command] -> [Command] +insertDeallocs cmds = + let allocated = Set.fromList [n | CAlloc _ _ n _ <- cmds] + `Set.union` Set.fromList [n | CKeepalive n <- cmds] + in fst $ foldr + (\cmd (rest, done) -> case cmd of + CChunk _ _ used -> + let todealloc = filter (\n -> n `Set.member` allocated && + n `Set.notMember` done) + used + in (cmd : map CDealloc todealloc ++ rest + ,done `Set.union` Set.fromList todealloc) + CAlloc _ _ name _ + | name `Set.notMember` done -> (rest, done) -- unused alloc + | otherwise -> (cmd : rest, Set.delete name done) + CKeepalive _ -> (rest, done) -- already handled above in @allocated@ + CDealloc _ -> error "insertDeallocs: CDealloc found") + ([], mempty) cmds + +compileCommands :: [Command] -> ([C.FunDef], [C.Stmt]) +compileCommands [] = ([], []) +compileCommands (CChunk defs code _ : cmds) = + bimap (defs ++) (code ++) (compileCommands cmds) +compileCommands (CAlloc defs typ name (C.StExpr szstmts szexpr) : cmds) = + let allocstmt = C.SDecl (C.TPtr typ) name + (Just (C.ECall (C.Name "malloc") [C.EOp szexpr "*" (C.ESizeOf typ)])) + in bimap (defs ++) ((szstmts ++ [allocstmt]) ++) (compileCommands cmds) +compileCommands (CDealloc name : cmds) = + second ([C.SCall (C.Name "free") [C.EVar name]] ++) (compileCommands cmds) +compileCommands (CKeepalive _ : cmds) = compileCommands cmds + + +compileAcc' :: AVarEnv aenv -> TupANames t -> A.OpenAcc aenv t -> SC [Command] +compileAcc' aenv dest (A.OpenAcc acc) = compilePAcc' aenv dest acc + +compilePAcc' :: AVarEnv aenv -> TupANames t -> A.PreOpenAcc A.OpenAcc aenv t -> SC [Command] +compilePAcc' aenv destnames = \case + A.Alet lhs rhs body -> do + (names, aenv') <- genVarsAEnv lhs aenv + let sts1 = [C.SDecl t n Nothing | TypedAName t n <- itupList names] + let cmds1 = [CChunk [] sts1 []] + cmds2 <- compileAcc' aenv names rhs + cmds3 <- compileAcc' aenv' destnames body + return (cmds1 ++ cmds2 ++ cmds3) + + A.Avar (Var _ idx) -> + return (Right ([], ITupSingle (C.EVar (aveprj aenv idx)))) + + A.Apair a b -> do + res1 <- compileAcc' aenv a + res2 <- compileAcc' aenv b + return (Left (\case + ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2 + ITupIgnore -> [] + ITupSingle _ -> error "wat")) + + _ -> throw "Unsupported Acc constructor" + where + toStExprs :: TypeR t -> Either (ANames t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t) + toStExprs ty (Left fun) = do + names <- genAVars ty + let sts1 = fun names + return (sts1, itupmap (\(TypedName _ n) -> C.EVar n) names) + toStExprs _ (Right pair) = return pair + + toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt] + toStoring (Left f) = f + toStoring (Right (sts, exs)) = (sts ++) . flip go exs + where + go :: Names t -> Exprs t -> [C.Stmt] + go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex] + go (ITupSingle _) _ = error "wat" + go ITupIgnore _ = [] + go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2 + go (ITupPair _ _) _ = error "wat" + +genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv') +genVarsAEnv (LeftHandSideWildcard _) env = return (ANIgnore, env) +genVarsAEnv (LeftHandSideSingle (ArrayR _ ty)) env = do + name <- genName "a" + ty' <- cvtType ty + return (ITupSingle (TypedAName ty' name), AVEPush _ name env) +genVarsAEnv (LeftHandSidePair lhs1 lhs2) env = do + (n1, env1) <- genVarsAEnv lhs1 env + (n2, env2) <- genVarsAEnv lhs2 env1 + return (ITupPair n1 n2, env2) + +genAVars :: TypeR t -> SC (ANames t) +genAVars TupRunit = return ITupIgnore +genAVars (TupRsingle ty) = genAVar ty +genAVars (TupRpair t1 t2) = ITupPair <$> genAVars t1 <*> genAVars t2 + +genAVar :: ScalarType t -> SC (ANames t) +genAVar ty = ITupSingle <$> (TypedAName <$> cvtType ty <*> genName "a") |