{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
module SC.Acc where

import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import Data.Bifunctor
import qualified Data.Set as Set

import qualified Language.C as C
import SC.Defs
import SC.Monad


-- | One step of generated C, in program order. 'insertDeallocs' decides
-- where arrays are freed, and 'compileCommands' lowers a list of commands
-- to C definitions and statements.
data Command
  = CChunk [C.FunDef]  -- ^ Emitted top-level function definitions
           [C.Stmt]    -- ^ Code to execute
           [C.Name]    -- ^ Array variables used
  | CAlloc [C.FunDef]  -- ^ Emitted top-level function definitions
           C.Type      -- ^ Element type of the allocated array
           C.Name      -- ^ Variable to store it in
           C.StExpr    -- ^ Code that computes the array size
  | CKeepalive C.Name  -- ^ Never deallocate this
  | CDealloc C.Name
  deriving (Show)

-- | Insert a 'CDealloc' right after the chunk containing the last use of
-- every array allocated by a 'CAlloc' in this list, and drop 'CAlloc's
-- whose array is never used afterwards.
--
-- Arrays named by a 'CKeepalive' are never freed, and their 'CAlloc' is
-- always kept, even when no use is visible in this list. The 'CKeepalive'
-- commands themselves are consumed (they do not appear in the output).
--
-- Calls 'error' if the input already contains a 'CDealloc'.
insertDeallocs :: [Command] -> [Command]
insertDeallocs cmds = fst (foldr step ([], Set.empty) cmds)
  where
    -- Arrays that must never be freed.
    kept = Set.fromList [n | CKeepalive n <- cmds]

    -- Arrays whose lifetime is managed here: everything allocated in this
    -- list except the kept ones, so that no 'CDealloc' is ever emitted for
    -- a kept array.
    owned = Set.fromList [n | CAlloc _ _ n _ <- cmds] `Set.difference` kept

    -- The fold walks the commands from last to first. @done@ holds the
    -- owned arrays whose free has already been scheduled, i.e. whose last
    -- use lies later in the program.
    step cmd (rest, done) = case cmd of
      CChunk _ _ used ->
        let (todealloc, done') = lastUses done used
        in (cmd : map CDealloc todealloc ++ rest, done')
      CAlloc _ _ name _
        | name `Set.member` kept -> (cmd : rest, done)
        | name `Set.notMember` done -> (rest, done)  -- unused alloc
        -- Removing the name lets an earlier (re)allocation of the same
        -- variable get its own free.
        | otherwise -> (cmd : rest, Set.delete name done)
      CKeepalive _ -> (rest, done)  -- already accounted for in @kept@
      CDealloc _ -> error "insertDeallocs: CDealloc found"

    -- Select, in order of appearance, the owned arrays in @used@ whose free
    -- is not yet scheduled, and return the updated @done@ set. Threading the
    -- set through the list also deduplicates: a name listed twice in one
    -- chunk is freed only once.
    lastUses done [] = ([], done)
    lastUses done (n : ns)
      | n `Set.member` owned && n `Set.notMember` done =
          let (rest', done') = lastUses (Set.insert n done) ns
          in (n : rest', done')
      | otherwise = lastUses done ns

-- | Lower a command list to C. Returns the top-level function definitions
-- to emit and the statements to execute, both in command order.
--
-- A 'CAlloc' becomes its size-computing statements followed by a pointer
-- declaration of the target variable, initialised with
-- @malloc(size * sizeof(elemtype))@. A 'CDealloc' becomes @free(name)@, and
-- 'CKeepalive' produces nothing.
compileCommands :: [Command] -> ([C.FunDef], [C.Stmt])
compileCommands [] = ([], [])
compileCommands (CChunk defs code _ : cmds) =
  bimap (defs ++) (code ++) (compileCommands cmds)
compileCommands (CAlloc defs typ name (C.StExpr szstmts szexpr) : cmds) =
  let allocstmt =
        C.SDecl (C.TPtr typ) name
          (Just (C.ECall (C.Name "malloc")
                         [C.EOp szexpr "*" (C.ESizeOf typ)]))
  in bimap (defs ++) ((szstmts ++ [allocstmt]) ++) (compileCommands cmds)
compileCommands (CDealloc name : cmds) =
  second ([C.SCall (C.Name "free") [C.EVar name]] ++) (compileCommands cmds)
compileCommands (CKeepalive _ : cmds) = compileCommands cmds

-- | Compile an array computation. @dest@ names the variables the result is
-- meant to end up in; see the 'A.Alet' case of 'compilePAcc''.
compileAcc' :: AVarEnv aenv -> TupANames t -> A.OpenAcc aenv t -> SC [Command]
compileAcc' aenv dest (A.OpenAcc acc) = compilePAcc' aenv dest acc

-- | Compile one array-level AST node into commands.
--
-- For 'A.Alet', fresh variables are generated for the binding. Their C
-- declarations are emitted as a chunk, then the right-hand side is compiled
-- into those variables and the body into @destnames@.
--
-- NOTE(review): the 'A.Avar' and 'A.Apair' cases do not match the
-- @SC [Command]@ signature: they return 'Right'/'Left' values, and 'A.Apair'
-- calls 'compileAcc'' without a destination argument. This looks like a
-- refactor in progress. Confirm the intended shape before relying on these
-- cases.
compilePAcc' :: AVarEnv aenv -> TupANames t -> A.PreOpenAcc A.OpenAcc aenv t -> SC [Command]
compilePAcc' aenv destnames = \case
    A.Alet lhs rhs body -> do
      (names, aenv') <- genVarsAEnv lhs aenv
      let sts1 = [C.SDecl t n Nothing | TypedAName t n <- itupList names]
      let cmds1 = [CChunk [] sts1 []]
      cmds2 <- compileAcc' aenv names rhs
      cmds3 <- compileAcc' aenv' destnames body
      return (cmds1 ++ cmds2 ++ cmds3)

    A.Avar (Var _ idx) ->
      return (Right ([], ITupSingle (C.EVar (aveprj aenv idx))))

    A.Apair a b -> do
      res1 <- compileAcc' aenv a
      res2 <- compileAcc' aenv b
      return (Left (\case
        ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2
        ITupIgnore -> []
        ITupSingle _ -> error "wat"))

    _ -> throw "Unsupported Acc constructor"
  where
    -- Normalise a result to (statements, expressions). A 'Left' store
    -- function is applied to freshly generated variables of type @ty@, and
    -- those variables are returned as expressions.
    -- NOTE(review): 'genAVars' yields 'ANames', but the lambda matches
    -- 'TypedName' rather than 'TypedAName'. Verify against SC.Defs. This
    -- helper is not referenced by the cases above.
    toStExprs :: TypeR t -> Either (ANames t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t)
    toStExprs ty (Left fun) = do
      names <- genAVars ty
      let sts1 = fun names
      return (sts1, itupmap (\(TypedName _ n) -> C.EVar n) names)
    toStExprs _ (Right pair) = return pair

    -- Produce statements that store a result into the given names. A
    -- 'Right' result runs its statements first, then assigns each
    -- expression to the matching name. Ignored names get no assignment,
    -- and a tuple-shape mismatch calls 'error'.
    toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt]
    toStoring (Left f) = f
    toStoring (Right (sts, exs)) = (sts ++) . flip go exs
      where
        go :: Names t -> Exprs t -> [C.Stmt]
        go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex]
        go (ITupSingle _) _ = error "wat"
        go ITupIgnore _ = []
        go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2
        go (ITupPair _ _) _ = error "wat"

-- | Generate a fresh C variable for every array bound by a left-hand side,
-- and extend the environment to match. Wildcards bind nothing.
--
-- NOTE(review): the single-array case passes a typed hole (@_@) to
-- 'AVEPush', and the wildcard case uses 'ANIgnore' where 'genAVars' uses
-- 'ITupIgnore'. Both look unfinished; confirm against SC.Defs.
genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv')
genVarsAEnv (LeftHandSideWildcard _) env = return (ANIgnore, env)
genVarsAEnv (LeftHandSideSingle (ArrayR _ ty)) env = do
  name <- genName "a"
  ty' <- cvtType ty
  return (ITupSingle (TypedAName ty' name), AVEPush _ name env)
genVarsAEnv (LeftHandSidePair lhs1 lhs2) env = do
  (n1, env1) <- genVarsAEnv lhs1 env
  (n2, env2) <- genVarsAEnv lhs2 env1
  return (ITupPair n1 n2, env2)

-- | Generate one fresh, typed array variable per scalar component of a
-- tuple type. Unit components get no variable.
genAVars :: TypeR t -> SC (ANames t)
genAVars TupRunit = return ITupIgnore
genAVars (TupRsingle ty) = genAVar ty
genAVars (TupRpair t1 t2) = ITupPair <$> genAVars t1 <*> genAVars t2

-- | Generate a single fresh array variable, named with prefix @"a"@, whose
-- C type is converted from the given scalar type.
genAVar :: ScalarType t -> SC (ANames t)
genAVar ty = ITupSingle <$> (TypedAName <$> cvtType ty <*> genName "a")