{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
module SC.Acc where

import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape hiding (zip)
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import Data.Bifunctor
import qualified Data.Set as Set

import qualified Language.C as C
import SC.Defs
import SC.Exp
import SC.Monad


-- | A unit of generated C code, before final assembly by 'compileCommands'.
-- Array allocation and deallocation are separate commands so that
-- 'insertDeallocs' can reason about array lifetimes.
data Command
  = CChunk [C.FunDef]  -- ^ Emitted top-level function definitions
           [C.Stmt]    -- ^ Code to execute
           [C.Name]    -- ^ Array variables used
  | CAlloc [C.FunDef]  -- ^ Emitted top-level function definitions
           C.Type      -- ^ Element type of the allocated array
           C.Name      -- ^ Variable to store it in (newly declared!)
           C.StExpr    -- ^ Code that computes the array size
  | CKeepalive C.Name  -- ^ Never deallocate this
  | CDealloc C.Name    -- ^ Free this array; only produced by 'insertDeallocs'
  deriving (Show)

-- | Insert a 'CDealloc' after the last use of every array that is allocated
-- with 'CAlloc' and not marked with 'CKeepalive' (the "collectable" arrays).
--
-- * Allocations of collectable arrays that are never used later are removed.
-- * Allocations of keep-alive arrays are always retained.
-- * 'CKeepalive' markers are consumed.
--
-- The input must not already contain 'CDealloc' commands. The list is
-- processed back to front. @done@ holds the arrays for which a deallocation
-- has already been placed after their last use.
insertDeallocs :: [Command] -> [Command]
insertDeallocs cmds =
  let collectable =
        Set.fromList [n | CAlloc _ _ n _ <- cmds]
          `Set.difference` Set.fromList [n | CKeepalive n <- cmds]

      step cmd (rest, done) = case cmd of
        CChunk _ _ used ->
          -- Going backwards, the first chunk that mentions an array is its
          -- last use. Deduplicate so an array listed twice in the same chunk
          -- is not freed twice.
          let isPending n = n `Set.member` collectable && n `Set.notMember` done
              todealloc = dedup (filter isPending used)
          in ( cmd : map CDealloc todealloc ++ rest
             , done `Set.union` Set.fromList todealloc )
        CAlloc _ _ name _
          -- A collectable array that nothing uses afterwards: drop the
          -- allocation entirely.
          | name `Set.member` collectable, name `Set.notMember` done -> (rest, done)
          -- Keep-alive arrays never enter @done@ but must still be allocated.
          | otherwise -> (cmd : rest, Set.delete name done)
        CKeepalive _ -> (rest, done)  -- already handled above in @collectable@
        CDealloc _ -> error "insertDeallocs: CDealloc found"

      -- Order-preserving removal of duplicate names.
      dedup xs = go Set.empty xs
        where
          go _ [] = []
          go seen (x : rest')
            | x `Set.member` seen = go seen rest'
            | otherwise = x : go (Set.insert x seen) rest'
  in fst (foldr step ([], mempty) cmds)

-- | Assemble commands into the top-level function definitions they emit and
-- the statements they execute, in order.
--
-- * 'CAlloc' runs its size statements and then declares the pointer
--   variable, initialised with @malloc(size * sizeof(element))@.
-- * 'CDealloc' becomes a @free@ call.
-- * 'CKeepalive' produces no code.
compileCommands :: [Command] -> ([C.FunDef], [C.Stmt])
compileCommands [] = ([], [])
compileCommands (CChunk defs code _ : cmds) =
  bimap (defs ++) (code ++) (compileCommands cmds)
compileCommands (CAlloc defs typ name (C.StExpr szstmts szexpr) : cmds) =
  let allocstmt = C.SDecl (C.TPtr typ) name
                    (Just (C.ECall (C.Name "malloc")
                                   [C.EOp szexpr "*" (C.ESizeOf typ)]))
  in bimap (defs ++) ((szstmts ++ [allocstmt]) ++) (compileCommands cmds)
compileCommands (CDealloc name : cmds) =
  second ([C.SCall (C.Name "free") [C.EVar name]] ++) (compileCommands cmds)
compileCommands (CKeepalive _ : cmds) = compileCommands cmds

-- | Compile an array computation so that its result ends up in the given
-- destination variables.
compileAcc' :: AVarEnv aenv -> TupANames t -> A.OpenAcc aenv t -> SC [Command]
compileAcc' aenv dest (A.OpenAcc acc) = compilePAcc' aenv dest acc

-- | Compile one array-level AST node into commands that leave its result in
-- the destination names.
--
-- Supported constructors are 'A.Alet', 'A.Avar', 'A.Anil', 'A.Apair' and
-- 'A.Generate'. Any other constructor, or a destination of the wrong form,
-- is reported with 'throw'.
compilePAcc'
  :: AVarEnv aenv -> TupANames t -> A.PreOpenAcc A.OpenAcc aenv t -> SC [Command]
compilePAcc' aenv destnames = \case
  A.Alet lhs rhs body -> do
    -- Declare fresh shape-size and array variables for the bound arrays,
    -- compile the right-hand side into them, then compile the body.
    -- NOTE(review): if rhs is a 'A.Generate', its 'CAlloc' commands make
    -- compileCommands emit a second 'C.SDecl' for these same array names;
    -- confirm the generated C does not contain duplicate declarations.
    (names, aenv') <- genVarsAEnv lhs aenv
    let sts1sh = [C.SDecl t n Nothing | TypedName t n <- fst (tupanamesList names)]
        sts1arr = [C.SDecl t n Nothing | TypedAName t n <- snd (tupanamesList names)]
    let cmds1 = [CChunk [] (sts1sh ++ sts1arr) []]
    cmds2 <- compileAcc' aenv names rhs
    cmds3 <- compileAcc' aenv' destnames body
    return (cmds1 ++ cmds2 ++ cmds3)

  A.Avar (Var _ idx)
    | ANArray destshnames destarrnames <- destnames -> do
        -- Assign the referenced array's shape sizes and data pointers to the
        -- destination variables. Only pointers are assigned; no element
        -- data is copied.
        -- NOTE(review): only the source arrays are reported as used. If this
        -- is their last use, insertDeallocs frees them right after this
        -- chunk while the destination still aliases them. Confirm the
        -- callers account for this.
        let (shnames, arrnames) = aveprj aenv idx
            sts = [C.SAsg destn (C.EVar srcn)
                  | (TypedName _ destn, TypedName _ srcn)
                      <- zip (shnamesList destshnames) (shnamesList shnames)]
                  ++ [C.SAsg destn (C.EVar srcn)
                     | (destn, srcn) <- zipDestSrcNamesAA destarrnames arrnames]
            usedA = map (\(TypedAName _ n) -> n) (itupList arrnames)
        return [CChunk [] sts usedA]

  A.Anil -> return []

  A.Apair a b
    | ANPair destnames1 destnames2 <- destnames -> do
        res1 <- compileAcc' aenv destnames1 a
        res2 <- compileAcc' aenv destnames2 b
        return (res1 ++ res2)

  A.Generate _ she fun@(A.Lam _ (A.Body (A.expType -> restype)))
    | ANArray destshnames destarrnames <- destnames -> do
        -- Three steps:
        --   1. Evaluate the shape expression into the destination size
        --      variables.
        --   2. Allocate one buffer per scalar component of the element type.
        --   3. Loop over all indices, call the element function and store
        --      each result component at the linear index.
        CompiledFun sheFD sheArgbuilder usedAshe <- compileExp aenv she
        CompiledFun funFD funArgbuilder usedAfun <- compileFun aenv fun
        tempnames <- genVars restype
        loops <- enumShapeNested destshnames $ \idxnames linidxexpr ->
          concat
            [[C.SDecl t n Nothing | TypedName t n <- itupList tempnames]
            ,[C.SCall (C.fundefName funFD)
                      (funArgbuilder (itupEvars (fromShNames idxnames)) tempnames)]
            ,[C.SStore arrname linidxexpr (C.EVar tempname)
             | (arrname, tempname) <- zipDestSrcNamesAE destarrnames tempnames]]
        return . concat $
          [[CChunk [sheFD]
                   [C.SCall (C.fundefName sheFD)
                            (sheArgbuilder ITupIgnore (fromShNames destshnames))]
                   (concatMap (\(SomeArray _ ans) -> map (\(TypedAName _ n) -> n) (itupList ans))
                              usedAshe)]
          ,[CAlloc [] eltty n (C.StExpr [] (computeSize destshnames))
           | TypedAName arrty n <- itupList destarrnames
           , let C.TPtr eltty = arrty]
          ,[CChunk [funFD] loops
                   (map (\(TypedAName _ n) -> n)
                        (itupList destarrnames
                         ++ concatMap (\(SomeArray _ ans) -> itupList ans) usedAfun))]]

  _ -> throw "Unsupported Acc constructor"

-- | Returns an expression of type int64_t: the total number of elements of an
-- array whose dimension sizes are held in the given variables. A rank-0
-- shape has exactly one element.
computeSize :: ShNames sh -> C.Expr
computeSize ShZ = C.ELit "1LL"
computeSize (ShS ShZ n) = C.EVar n
computeSize (ShS ns n) = C.EOp (computeSize ns) "*" (C.EVar n)

-- | Given size variables and index variables, returns an expression of type
-- int64_t: the row-major linear index. The index of the outermost 'ShS'
-- component varies fastest.
--
-- A rank-0 array has a single element, stored at index 0. This matches
-- 'computeSize' returning 1 for 'ShZ'.
linearIndexExpr :: ShNames sh -> ShNames sh -> C.Expr
linearIndexExpr ShZ ShZ = C.ELit "0LL"
linearIndexExpr (ShS ShZ _) (ShS ShZ i) = C.EVar i
linearIndexExpr (ShS ns n) (ShS is i) =
  C.EOp (C.EOp (linearIndexExpr ns is) "*" (C.EVar n)) "+" (C.EVar i)

-- | Pair up destination and source names that have the same tuple structure.
-- Components ignored in the destination are skipped. A component that is
-- ignored in the source but not in the destination is an error.
zipDestSrcNames :: ITup C.Name t -> ITup C.Name t -> [(C.Name, C.Name)]
zipDestSrcNames ITupIgnore _ = []
zipDestSrcNames _ ITupIgnore = error "Ignore in source names but not in destination names"
zipDestSrcNames (ITupSingle n) (ITupSingle n') = [(n, n')]
zipDestSrcNames (ITupPair a b) (ITupPair a' b') =
  zipDestSrcNames a a' ++ zipDestSrcNames b b'
-- Mismatched structure; presumably unreachable given the shared index @t@.
zipDestSrcNames _ _ = error "wat"

-- | 'zipDestSrcNames' for two sets of array names, ignoring their types.
zipDestSrcNamesAA :: ANames t -> ANames t -> [(C.Name, C.Name)]
zipDestSrcNamesAA ns1 ns2 =
  zipDestSrcNames (itupmap (\(TypedAName _ n) -> n) ns1)
                  (itupmap (\(TypedAName _ n) -> n) ns2)

-- | 'zipDestSrcNames' for array destination names and scalar source names,
-- ignoring their types.
zipDestSrcNamesAE :: ANames t -> Names t -> [(C.Name, C.Name)]
zipDestSrcNamesAE ns1 ns2 =
  zipDestSrcNames (itupmap (\(TypedAName _ n) -> n) ns1)
                  (itupmap (\(TypedName _ n) -> n) ns2)

-- | Given:
--
-- * the shape size variables, and
-- * a function that receives the index variables and an expression computing
--   the linear index from the size and index variables,
--
-- returns a nest of 64-bit @for@ loops (one per dimension, each running from
-- 0 to its size) whose innermost body is the function's result. For a rank-0
-- shape the body is returned without any loop.
enumShapeNested :: ShNames sh -> (ShNames sh -> C.Expr -> [C.Stmt]) -> SC [C.Stmt]
enumShapeNested sizenames fun = do
  idxnames <- genShNames (shNamesShape sizenames)
  let makeLoops :: ShNames sh -> ShNames sh -> [C.Stmt] -> [C.Stmt]
      makeLoops ShZ ShZ body = body
      makeLoops (ShS ns n) (ShS is i) body =
        makeLoops ns is [C.SFor (C.TInt C.B64) i (C.ELit "0") (C.EVar n) body]
  return (makeLoops sizenames idxnames (fun idxnames (linearIndexExpr sizenames idxnames)))

-- | Generate fresh variables for the arrays bound by an array left-hand side,
-- and push them onto the environment. Wildcards get no names and leave the
-- environment unchanged.
genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv')
genVarsAEnv (LeftHandSideWildcard _) env = return (ANIgnore, env)
genVarsAEnv (LeftHandSideSingle (ArrayR sht ty)) env = do
  shnames <- genShNames sht
  names <- genAVars ty
  return (ANArray shnames names, AVEPush shnames names env)
genVarsAEnv (LeftHandSidePair lhs1 lhs2) env = do
  (n1, env1) <- genVarsAEnv lhs1 env
  (n2, env2) <- genVarsAEnv lhs2 env1
  return (ANPair n1 n2, env2)

-- | Fresh shape-size and array variables for every array in an arrays type.
-- Unit gets no names.
genAVarsTup :: ArraysR t -> SC (TupANames t)
genAVarsTup TupRunit = return ANIgnore
genAVarsTup (TupRsingle (ArrayR sht ty)) = ANArray <$> genShNames sht <*> genAVars ty
genAVarsTup (TupRpair t1 t2) = ANPair <$> genAVarsTup t1 <*> genAVarsTup t2

-- | One fresh array pointer variable per scalar component of an element type.
genAVars :: TypeR t -> SC (ANames t)
genAVars TupRunit = return ITupIgnore
genAVars (TupRsingle ty) = genAVar ty
genAVars (TupRpair t1 t2) = ITupPair <$> genAVars t1 <*> genAVars t2

-- | One fresh size variable per dimension of a shape (names from @genName "n"@).
genShNames :: ShapeR sh -> SC (ShNames sh)
genShNames ShapeRz = return ShZ
genShNames (ShapeRsnoc sht) = do
  names <- genShNames sht
  name <- genName "n"
  return (ShS names name)

-- | A fresh pointer-to-element variable for one scalar type
-- (name from @genName "a"@).
genAVar :: ScalarType t -> SC (ANames t)
genAVar ty = ITupSingle <$> (TypedAName <$> fmap C.TPtr (cvtType ty) <*> genName "a")