diff options
Diffstat (limited to 'SC/Acc.hs')
-rw-r--r-- | SC/Acc.hs | 72 |
1 files changed, 67 insertions, 5 deletions
@@ -1,5 +1,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ViewPatterns #-} module SC.Acc where import qualified Data.Array.Accelerate.AST as A @@ -14,6 +15,7 @@ import qualified Data.Set as Set import qualified Language.C as C import SC.Defs +import SC.Exp import SC.Monad @@ -23,7 +25,7 @@ data Command [C.Name] -- ^ Array variables used | CAlloc [C.FunDef] -- ^ Emitted top-level function definitions C.Type -- ^ Element type of the allocated array - C.Name -- ^ Variable to store it in + C.Name -- ^ Variable to store it in (newly declared!) C.StExpr -- ^ Code that computes the array size | CKeepalive C.Name -- ^ Never deallocate this | CDealloc C.Name @@ -82,7 +84,7 @@ compilePAcc' aenv destnames = \case | (TypedName _ destn, TypedName _ srcn) <- zip (shnamesList destshnames) (shnamesList shnames)] ++ [C.SAsg destn (C.EVar srcn) - | (destn, srcn) <- zipDestSrcNames destarrnames arrnames] + | (destn, srcn) <- zipDestSrcNamesAA destarrnames arrnames] usedA = map (\(TypedAName _ n) -> n) (itupList arrnames) return [CChunk [] sts usedA] @@ -92,15 +94,75 @@ compilePAcc' aenv destnames = \case res2 <- compileAcc' aenv destnames2 b return (res1 ++ res2) + A.Generate _ she fun@(A.Lam _ (A.Body (A.expType -> restype))) + | ANArray destshnames destarrnames <- destnames -> do + CompiledFun sheFD sheArgbuilder usedAshe <- compileExp aenv she + CompiledFun funFD funArgbuilder usedAfun <- compileFun aenv fun + tempnames <- genVars restype + loops <- enumShapeNested destshnames $ \idxnames linidxexpr -> concat + [[C.SCall (C.fundefName funFD) + (funArgbuilder (itupEvars (fromShNames idxnames)) tempnames)] + ,[C.SStore arrname linidxexpr (C.EVar tempname) + | (arrname, tempname) <- zipDestSrcNamesAE destarrnames tempnames]] + return . concat $ + [[CChunk [sheFD] + [C.SCall (C.fundefName sheFD) + (sheArgbuilder ITupIgnore (fromShNames destshnames))] + (map (\(TypedAName _ n) -> n) usedAshe)] + ,[CAlloc [] eltty n (C.StExpr [] (computeSize destshnames)) + | TypedAName arrty n <- itupList destarrnames + , let C.TPtr eltty = arrty] + ,[CChunk [funFD] + loops + (map (\(TypedAName _ n) -> n) (itupList destarrnames ++ usedAfun))]] + _ -> throw "Unsupported Acc constructor" -zipDestSrcNames :: ANames e -> ANames e -> [(C.Name, C.Name)] +-- | Returns an expression of type int64_t +computeSize :: ShNames sh -> C.Expr +computeSize ShZ = C.ELit "1LL" +computeSize (ShS n ShZ) = C.EVar n +computeSize (ShS n ns) = C.EOp (C.EVar n) "*" (computeSize ns) + +-- | Given size variables and index variables, returns an expression of type int64_t +linearIndexExpr :: ShNames sh -> ShNames sh -> C.Expr +linearIndexExpr ShZ ShZ = C.ELit "1LL" +linearIndexExpr (ShS _ ShZ) (ShS i ShZ) = C.EVar i +linearIndexExpr (ShS n ns) (ShS i is) = + C.EOp (C.EOp (linearIndexExpr ns is) "*" (C.EVar n)) "+" (C.EVar i) + +zipDestSrcNames :: ITup C.Name e -> ITup C.Name e -> [(C.Name, C.Name)] zipDestSrcNames ITupIgnore _ = [] -zipDestSrcNames _ ITupIgnore = error "Ignore in source names where there is none in the destination names" -zipDestSrcNames (ITupSingle (TypedAName _ n)) (ITupSingle (TypedAName _ n')) = [(n, n')] +zipDestSrcNames _ ITupIgnore = error "Ignore in source names but not in destination names" +zipDestSrcNames (ITupSingle n) (ITupSingle n') = [(n, n')] zipDestSrcNames (ITupPair a b) (ITupPair a' b') = zipDestSrcNames a a' ++ zipDestSrcNames b b' zipDestSrcNames _ _ = error "wat" +zipDestSrcNamesAA :: ANames e -> ANames e -> [(C.Name, C.Name)] +zipDestSrcNamesAA ns1 ns2 = + zipDestSrcNames (itupmap (\(TypedAName _ n) -> n) ns1) + (itupmap (\(TypedAName _ n) -> n) ns2) + +zipDestSrcNamesAE :: ANames e -> Names e -> [(C.Name, C.Name)] +zipDestSrcNamesAE ns1 ns2 = + zipDestSrcNames (itupmap (\(TypedAName _ n) -> n) ns1) + (itupmap (\(TypedName _ n) -> n) ns2) + +-- | Given: +-- - shape size variables +-- - a function taking +-- - index variables +-- - an expression computing the linear index from the size and index variables +-- returns a nested loop statement where the loop body is given by the function. +enumShapeNested :: ShNames sh -> (ShNames sh -> C.Expr -> [C.Stmt]) -> SC [C.Stmt] +enumShapeNested sizenames fun = do + idxnames <- genShNames (shNamesShape sizenames) + let makeLoops :: ShNames sh -> ShNames sh -> [C.Stmt] -> [C.Stmt] + makeLoops ShZ ShZ body = body + makeLoops (ShS n ns) (ShS i is) body = + makeLoops ns is [C.SFor (C.TInt C.B64) i (C.ELit "0") (C.EVar n) body] + return (makeLoops sizenames idxnames (fun idxnames (linearIndexExpr sizenames idxnames))) + genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv') genVarsAEnv (LeftHandSideWildcard _) env = return (ANIgnore, env) genVarsAEnv (LeftHandSideSingle (ArrayR sht ty)) env = do |