{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
module SC.Acc where

import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape hiding (zip)
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import Data.Bifunctor
import qualified Data.Set as Set
import qualified Language.C as C
import SC.Defs
import SC.Monad


-- | One step of a compiled program. A list of these is post-processed by
-- 'insertDeallocs' and lowered to C by 'compileCommands'.
data Command
  = CChunk
      [C.FunDef]  -- ^ Emitted top-level function definitions
      [C.Stmt]    -- ^ Code to execute
      [C.Name]    -- ^ Array variables used
  | CAlloc
      [C.FunDef]  -- ^ Emitted top-level function definitions
      C.Type      -- ^ Element type of the allocated array
      C.Name      -- ^ Variable to store it in
      C.StExpr    -- ^ Code that computes the array size
  | CKeepalive
      C.Name      -- ^ Never deallocate this
  | CDealloc
      C.Name      -- ^ Array variable to release with @free@
  deriving (Show)

-- | Insert a 'CDealloc' right after the last chunk that uses each collectable
-- array, and drop allocations of collectable arrays that no chunk uses.
--
-- An array is /collectable/ when some 'CAlloc' allocates it and no
-- 'CKeepalive' mentions it. Arrays that are kept alive are always allocated
-- and never deallocated.
--
-- The input must not already contain 'CDealloc' commands (this is an
-- 'error'). All 'CKeepalive' commands are removed from the output.
insertDeallocs :: [Command] -> [Command]
insertDeallocs cmds = fst (foldr step ([], mempty) cmds)
  where
    collectable =
      Set.fromList [n | CAlloc _ _ n _ <- cmds]
        `Set.difference` Set.fromList [n | CKeepalive n <- cmds]

    -- The list is walked back to front. @done@ holds the collectable names
    -- whose last use has already been seen, i.e. that already have a
    -- 'CDealloc' somewhere later in the output.
    step cmd (rest, done) = case cmd of
      CChunk _ _ used ->
        -- Go through a Set so that a name listed twice in @used@ is still
        -- freed only once.
        let todealloc =
              (Set.fromList used `Set.intersection` collectable)
                `Set.difference` done
        in ( cmd : map CDealloc (Set.toList todealloc) ++ rest
           , done `Set.union` todealloc )
      CAlloc _ _ name _
        -- Kept-alive arrays are never in @done@, but must still be allocated.
        | name `Set.notMember` collectable -> (cmd : rest, done)
        | name `Set.notMember` done -> (rest, done)  -- unused alloc
        | otherwise -> (cmd : rest, Set.delete name done)
      CKeepalive _ -> (rest, done)  -- already handled above in @collectable@
      CDealloc _ -> error "insertDeallocs: CDealloc found"

-- | Lower a command list to the top-level function definitions it emits and
-- the statements to execute, in order.
--
-- * 'CAlloc' becomes the size-computation statements followed by a pointer
--   declaration initialised with
--   @malloc(size * sizeof(elemtype))@.
-- * 'CDealloc' becomes a call to @free@.
-- * 'CKeepalive' produces no code.
compileCommands :: [Command] -> ([C.FunDef], [C.Stmt])
compileCommands [] = ([], [])
compileCommands (CChunk defs code _ : cmds) =
  bimap (defs ++) (code ++) (compileCommands cmds)
compileCommands (CAlloc defs typ name (C.StExpr szstmts szexpr) : cmds) =
  -- NOTE(review): the malloc result is not checked for NULL.
  let allocstmt =
        C.SDecl (C.TPtr typ) name
          (Just (C.ECall (C.Name "malloc") [C.EOp szexpr "*" (C.ESizeOf typ)]))
  in bimap (defs ++) ((szstmts ++ [allocstmt]) ++) (compileCommands cmds)
compileCommands (CDealloc name : cmds) =
  second ([C.SCall (C.Name "free") [C.EVar name]] ++) (compileCommands cmds)
compileCommands (CKeepalive _ : cmds) = compileCommands cmds

-- | Compile an array computation so that its result is stored in the given
-- destination names. This unwraps 'A.OpenAcc' and defers to 'compilePAcc''.
compileAcc' :: AVarEnv aenv -> TupANames t -> A.OpenAcc aenv t -> SC [Command]
compileAcc' aenv dest (A.OpenAcc acc) = compilePAcc' aenv dest acc

-- | Compile one array-level AST node into commands that write its result to
-- @destnames@.
--
-- Only three forms are supported:
--
-- * 'A.Alet';
-- * 'A.Avar' into an 'ANArray' destination;
-- * 'A.Apair' into an 'ANPair' destination.
--
-- Anything else fails in 'SC' with @"Unsupported Acc constructor"@.
compilePAcc' :: AVarEnv aenv -> TupANames t -> A.PreOpenAcc A.OpenAcc aenv t -> SC [Command]
compilePAcc' aenv destnames = \case
  A.Alet lhs rhs body -> do
    -- Fresh variables for the bound values. They are declared uninitialised
    -- here, filled in by the right-hand side, then visible in the body.
    (names, aenv') <- genVarsAEnv lhs aenv
    -- NOTE(review): array variables are declared with their 'TypedAName' type
    -- as-is, whereas 'compileCommands' declares allocated arrays as
    -- 'C.TPtr' of the element type; confirm these agree.
    let sts1sh = [C.SDecl t n Nothing | TypedName t n <- fst (tupanamesList names)]
        sts1arr = [C.SDecl t n Nothing | TypedAName t n <- snd (tupanamesList names)]
    let cmds1 = [CChunk [] (sts1sh ++ sts1arr) []]
    cmds2 <- compileAcc' aenv names rhs
    cmds3 <- compileAcc' aenv' destnames body
    return (cmds1 ++ cmds2 ++ cmds3)
  A.Avar (Var _ idx)
    | ANArray destshnames destarrnames <- destnames -> do
        -- Copy the shape variables and array variables of the referenced
        -- binding into the destination.
        -- NOTE(review): only the source arrays are reported as used, so
        -- 'insertDeallocs' may free them after this chunk while the
        -- destination still aliases them; confirm that cannot be observed.
        let (shnames, arrnames) = aveprj aenv idx
            sts =
              [ C.SAsg destn (C.EVar srcn)
              | (TypedName _ destn, TypedName _ srcn) <-
                  zip (shnamesList destshnames) (shnamesList shnames)
              ]
                ++ [ C.SAsg destn (C.EVar srcn)
                   | (destn, srcn) <- zipDestSrcNames destarrnames arrnames
                   ]
            usedA = map (\(TypedAName _ n) -> n) (itupList arrnames)
        return [CChunk [] sts usedA]
  A.Apair a b
    | ANPair destnames1 destnames2 <- destnames -> do
        res1 <- compileAcc' aenv destnames1 a
        res2 <- compileAcc' aenv destnames2 b
        return (res1 ++ res2)
  _ -> throw "Unsupported Acc constructor"

-- | Pair up destination and source array variable names component-wise.
-- Components ignored in the destination are skipped. It is an 'error' for
-- the source to ignore a component that the destination does not ignore.
zipDestSrcNames :: ANames e -> ANames e -> [(C.Name, C.Name)]
zipDestSrcNames ITupIgnore _ = []
zipDestSrcNames _ ITupIgnore = error "Ignore in source names where there is none in the destination names"
zipDestSrcNames (ITupSingle (TypedAName _ n)) (ITupSingle (TypedAName _ n')) = [(n, n')]
zipDestSrcNames (ITupPair a b) (ITupPair a' b') = zipDestSrcNames a a' ++ zipDestSrcNames b b'
zipDestSrcNames _ _ = error "zipDestSrcNames: destination and source name tuples have mismatched structure"

-- | Generate fresh C variable names for everything bound by an array
-- left-hand side, and push them onto the environment.
--
-- A wildcard binds nothing: it yields 'ANIgnore' and leaves the environment
-- unchanged. Pairs are processed left component first.
genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv')
genVarsAEnv (LeftHandSideWildcard _) env = return (ANIgnore, env)
genVarsAEnv (LeftHandSideSingle (ArrayR sht ty)) env = do
  shnames <- genShNames sht
  names <- genAVars ty
  return (ANArray shnames names, AVEPush shnames names env)
genVarsAEnv (LeftHandSidePair lhs1 lhs2) env = do
  (n1, env1) <- genVarsAEnv lhs1 env
  (n2, env2) <- genVarsAEnv lhs2 env1
  return (ANPair n1 n2, env2)

-- | Generate one fresh array variable per scalar component of an element
-- type. Unit components get no variable ('ITupIgnore').
genAVars :: TypeR t -> SC (ANames t)
genAVars TupRunit = return ITupIgnore
genAVars (TupRsingle ty) = genAVar ty
genAVars (TupRpair t1 t2) = ITupPair <$> genAVars t1 <*> genAVars t2

-- | Generate one fresh variable (prefix @"n"@) per dimension of a shape.
genShNames :: ShapeR sh -> SC (ShNames sh)
genShNames ShapeRz = return ShZ
genShNames (ShapeRsnoc sht) = do
  name <- genName "n"
  names <- genShNames sht
  return (ShS name names)

-- | Generate a fresh array variable (prefix @"a"@) for a single scalar
-- element type. Its C type comes from 'cvtType'.
genAVar :: ScalarType t -> SC (ANames t)
genAVar ty = ITupSingle <$> (TypedAName <$> cvtType ty <*> genName "a")