summaryrefslogtreecommitdiff
path: root/SC/Acc.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2021-09-24 22:49:44 +0200
committerTom Smeding <tom@tomsmeding.com>2021-09-24 22:49:44 +0200
commit070772f008bcb5edb63f3f2c2c5f10c4eb9cb008 (patch)
tree4b3278f339c140b17f7f2cc9d9bbcee83235526a /SC/Acc.hs
parentf42d7f4562ea2e5c9ef634665952e38630f17ae4 (diff)
Potentially generate some code for Generate
Diffstat (limited to 'SC/Acc.hs')
-rw-r--r--SC/Acc.hs72
1 files changed, 67 insertions, 5 deletions
diff --git a/SC/Acc.hs b/SC/Acc.hs
index 1424f52..b50bf24 100644
--- a/SC/Acc.hs
+++ b/SC/Acc.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ViewPatterns #-}
module SC.Acc where
import qualified Data.Array.Accelerate.AST as A
@@ -14,6 +15,7 @@ import qualified Data.Set as Set
import qualified Language.C as C
import SC.Defs
+import SC.Exp
import SC.Monad
@@ -23,7 +25,7 @@ data Command
[C.Name] -- ^ Array variables used
| CAlloc [C.FunDef] -- ^ Emitted top-level function definitions
C.Type -- ^ Element type of the allocated array
- C.Name -- ^ Variable to store it in
+ C.Name -- ^ Variable to store it in (newly declared!)
C.StExpr -- ^ Code that computes the array size
| CKeepalive C.Name -- ^ Never deallocate this
| CDealloc C.Name
@@ -82,7 +84,7 @@ compilePAcc' aenv destnames = \case
| (TypedName _ destn, TypedName _ srcn) <- zip (shnamesList destshnames) (shnamesList shnames)]
++
[C.SAsg destn (C.EVar srcn)
- | (destn, srcn) <- zipDestSrcNames destarrnames arrnames]
+ | (destn, srcn) <- zipDestSrcNamesAA destarrnames arrnames]
usedA = map (\(TypedAName _ n) -> n) (itupList arrnames)
return [CChunk [] sts usedA]
@@ -92,15 +94,75 @@ compilePAcc' aenv destnames = \case
res2 <- compileAcc' aenv destnames2 b
return (res1 ++ res2)
+ A.Generate _ she fun@(A.Lam _ (A.Body (A.expType -> restype)))
+ | ANArray destshnames destarrnames <- destnames -> do
+ CompiledFun sheFD sheArgbuilder usedAshe <- compileExp aenv she
+ CompiledFun funFD funArgbuilder usedAfun <- compileFun aenv fun
+ tempnames <- genVars restype
+ loops <- enumShapeNested destshnames $ \idxnames linidxexpr -> concat
+ [[C.SCall (C.fundefName funFD)
+ (funArgbuilder (itupEvars (fromShNames idxnames)) tempnames)]
+ ,[C.SStore arrname linidxexpr (C.EVar tempname)
+ | (arrname, tempname) <- zipDestSrcNamesAE destarrnames tempnames]]
+ return . concat $
+ [[CChunk [sheFD]
+ [C.SCall (C.fundefName sheFD)
+ (sheArgbuilder ITupIgnore (fromShNames destshnames))]
+ (map (\(TypedAName _ n) -> n) usedAshe)]
+ ,[CAlloc [] eltty n (C.StExpr [] (computeSize destshnames))
+ | TypedAName arrty n <- itupList destarrnames
+ , let C.TPtr eltty = arrty]
+ ,[CChunk [funFD]
+ loops
+ (map (\(TypedAName _ n) -> n) (itupList destarrnames ++ usedAfun))]]
+
_ -> throw "Unsupported Acc constructor"
-zipDestSrcNames :: ANames e -> ANames e -> [(C.Name, C.Name)]
+-- | Returns an expression of type int64_t
+computeSize :: ShNames sh -> C.Expr
+computeSize ShZ = C.ELit "1LL"
+computeSize (ShS n ShZ) = C.EVar n
+computeSize (ShS n ns) = C.EOp (C.EVar n) "*" (computeSize ns)
+
+-- | Given size variables and index variables, returns an expression of type int64_t
+linearIndexExpr :: ShNames sh -> ShNames sh -> C.Expr
+linearIndexExpr ShZ ShZ = C.ELit "1LL"
+linearIndexExpr (ShS _ ShZ) (ShS i ShZ) = C.EVar i
+linearIndexExpr (ShS n ns) (ShS i is) =
+ C.EOp (C.EOp (linearIndexExpr ns is) "*" (C.EVar n)) "+" (C.EVar i)
+
+zipDestSrcNames :: ITup C.Name e -> ITup C.Name e -> [(C.Name, C.Name)]
zipDestSrcNames ITupIgnore _ = []
-zipDestSrcNames _ ITupIgnore = error "Ignore in source names where there is none in the destination names"
-zipDestSrcNames (ITupSingle (TypedAName _ n)) (ITupSingle (TypedAName _ n')) = [(n, n')]
+zipDestSrcNames _ ITupIgnore = error "Ignore in source names but not in destination names"
+zipDestSrcNames (ITupSingle n) (ITupSingle n') = [(n, n')]
zipDestSrcNames (ITupPair a b) (ITupPair a' b') = zipDestSrcNames a a' ++ zipDestSrcNames b b'
zipDestSrcNames _ _ = error "wat"
+zipDestSrcNamesAA :: ANames e -> ANames e -> [(C.Name, C.Name)]
+zipDestSrcNamesAA ns1 ns2 =
+ zipDestSrcNames (itupmap (\(TypedAName _ n) -> n) ns1)
+ (itupmap (\(TypedAName _ n) -> n) ns2)
+
+zipDestSrcNamesAE :: ANames e -> Names e -> [(C.Name, C.Name)]
+zipDestSrcNamesAE ns1 ns2 =
+ zipDestSrcNames (itupmap (\(TypedAName _ n) -> n) ns1)
+ (itupmap (\(TypedName _ n) -> n) ns2)
+
+-- | Given:
+-- - shape size variables
+-- - a function taking
+-- - index variables
+-- - an expression computing the linear index from the size and index variables
+-- returns a nested loop statement where the loop body is given by the function.
+enumShapeNested :: ShNames sh -> (ShNames sh -> C.Expr -> [C.Stmt]) -> SC [C.Stmt]
+enumShapeNested sizenames fun = do
+ idxnames <- genShNames (shNamesShape sizenames)
+ let makeLoops :: ShNames sh -> ShNames sh -> [C.Stmt] -> [C.Stmt]
+ makeLoops ShZ ShZ body = body
+ makeLoops (ShS n ns) (ShS i is) body =
+ makeLoops ns is [C.SFor (C.TInt C.B64) i (C.ELit "0") (C.EVar n) body]
+ return (makeLoops sizenames idxnames (fun idxnames (linearIndexExpr sizenames idxnames)))
+
genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv')
genVarsAEnv (LeftHandSideWildcard _) env = return (ANIgnore, env)
genVarsAEnv (LeftHandSideSingle (ArrayR sht ty)) env = do