summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2021-09-24 22:49:44 +0200
committerTom Smeding <tom@tomsmeding.com>2021-09-24 22:49:44 +0200
commit070772f008bcb5edb63f3f2c2c5f10c4eb9cb008 (patch)
tree4b3278f339c140b17f7f2cc9d9bbcee83235526a
parentf42d7f4562ea2e5c9ef634665952e38630f17ae4 (diff)
Potentially generate some code for Generate
-rw-r--r--Language/C.hs5
-rw-r--r--SC/Acc.hs72
-rw-r--r--SC/Defs.hs18
-rw-r--r--SC/Exp.hs3
4 files changed, 93 insertions, 5 deletions
diff --git a/Language/C.hs b/Language/C.hs
index 86250dd..35cf432 100644
--- a/Language/C.hs
+++ b/Language/C.hs
@@ -47,3 +47,8 @@ data Expr
| EPtrTo Expr
| ESizeOf Type
deriving (Show, Eq)
+
+
+fundefName :: FunDef -> Name
+fundefName (FunDef _ n _ _) = n
+fundefName (ProcDef n _ _) = n
diff --git a/SC/Acc.hs b/SC/Acc.hs
index 1424f52..b50bf24 100644
--- a/SC/Acc.hs
+++ b/SC/Acc.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ViewPatterns #-}
module SC.Acc where
import qualified Data.Array.Accelerate.AST as A
@@ -14,6 +15,7 @@ import qualified Data.Set as Set
import qualified Language.C as C
import SC.Defs
+import SC.Exp
import SC.Monad
@@ -23,7 +25,7 @@ data Command
[C.Name] -- ^ Array variables used
| CAlloc [C.FunDef] -- ^ Emitted top-level function definitions
C.Type -- ^ Element type of the allocated array
- C.Name -- ^ Variable to store it in
+ C.Name -- ^ Variable to store it in (newly declared!)
C.StExpr -- ^ Code that computes the array size
| CKeepalive C.Name -- ^ Never deallocate this
| CDealloc C.Name
@@ -82,7 +84,7 @@ compilePAcc' aenv destnames = \case
| (TypedName _ destn, TypedName _ srcn) <- zip (shnamesList destshnames) (shnamesList shnames)]
++
[C.SAsg destn (C.EVar srcn)
- | (destn, srcn) <- zipDestSrcNames destarrnames arrnames]
+ | (destn, srcn) <- zipDestSrcNamesAA destarrnames arrnames]
usedA = map (\(TypedAName _ n) -> n) (itupList arrnames)
return [CChunk [] sts usedA]
@@ -92,15 +94,75 @@ compilePAcc' aenv destnames = \case
res2 <- compileAcc' aenv destnames2 b
return (res1 ++ res2)
+ A.Generate _ she fun@(A.Lam _ (A.Body (A.expType -> restype)))
+ | ANArray destshnames destarrnames <- destnames -> do
+ CompiledFun sheFD sheArgbuilder usedAshe <- compileExp aenv she
+ CompiledFun funFD funArgbuilder usedAfun <- compileFun aenv fun
+ tempnames <- genVars restype
+ loops <- enumShapeNested destshnames $ \idxnames linidxexpr -> concat
+ [[C.SCall (C.fundefName funFD)
+ (funArgbuilder (itupEvars (fromShNames idxnames)) tempnames)]
+ ,[C.SStore arrname linidxexpr (C.EVar tempname)
+ | (arrname, tempname) <- zipDestSrcNamesAE destarrnames tempnames]]
+ return . concat $
+ [[CChunk [sheFD]
+ [C.SCall (C.fundefName sheFD)
+ (sheArgbuilder ITupIgnore (fromShNames destshnames))]
+ (map (\(TypedAName _ n) -> n) usedAshe)]
+ ,[CAlloc [] eltty n (C.StExpr [] (computeSize destshnames))
+ | TypedAName arrty n <- itupList destarrnames
+ , let C.TPtr eltty = arrty]
+ ,[CChunk [funFD]
+ loops
+ (map (\(TypedAName _ n) -> n) (itupList destarrnames ++ usedAfun))]]
+
_ -> throw "Unsupported Acc constructor"
-zipDestSrcNames :: ANames e -> ANames e -> [(C.Name, C.Name)]
+-- | Returns an expression of type int64_t
+computeSize :: ShNames sh -> C.Expr
+computeSize ShZ = C.ELit "1LL"
+computeSize (ShS n ShZ) = C.EVar n
+computeSize (ShS n ns) = C.EOp (C.EVar n) "*" (computeSize ns)
+
+-- | Given size variables and index variables, returns an expression of type int64_t
+linearIndexExpr :: ShNames sh -> ShNames sh -> C.Expr
+linearIndexExpr ShZ ShZ = C.ELit "1LL"
+linearIndexExpr (ShS _ ShZ) (ShS i ShZ) = C.EVar i
+linearIndexExpr (ShS n ns) (ShS i is) =
+ C.EOp (C.EOp (linearIndexExpr ns is) "*" (C.EVar n)) "+" (C.EVar i)
+
+zipDestSrcNames :: ITup C.Name e -> ITup C.Name e -> [(C.Name, C.Name)]
zipDestSrcNames ITupIgnore _ = []
-zipDestSrcNames _ ITupIgnore = error "Ignore in source names where there is none in the destination names"
-zipDestSrcNames (ITupSingle (TypedAName _ n)) (ITupSingle (TypedAName _ n')) = [(n, n')]
+zipDestSrcNames _ ITupIgnore = error "Ignore in source names but not in destination names"
+zipDestSrcNames (ITupSingle n) (ITupSingle n') = [(n, n')]
zipDestSrcNames (ITupPair a b) (ITupPair a' b') = zipDestSrcNames a a' ++ zipDestSrcNames b b'
zipDestSrcNames _ _ = error "wat"
+zipDestSrcNamesAA :: ANames e -> ANames e -> [(C.Name, C.Name)]
+zipDestSrcNamesAA ns1 ns2 =
+ zipDestSrcNames (itupmap (\(TypedAName _ n) -> n) ns1)
+ (itupmap (\(TypedAName _ n) -> n) ns2)
+
+zipDestSrcNamesAE :: ANames e -> Names e -> [(C.Name, C.Name)]
+zipDestSrcNamesAE ns1 ns2 =
+ zipDestSrcNames (itupmap (\(TypedAName _ n) -> n) ns1)
+ (itupmap (\(TypedName _ n) -> n) ns2)
+
+-- | Given:
+-- - shape size variables
+-- - a function taking
+-- - index variables
+-- - an expression computing the linear index from the size and index variables
+-- returns a nested loop statement where the loop body is given by the function.
+enumShapeNested :: ShNames sh -> (ShNames sh -> C.Expr -> [C.Stmt]) -> SC [C.Stmt]
+enumShapeNested sizenames fun = do
+ idxnames <- genShNames (shNamesShape sizenames)
+ let makeLoops :: ShNames sh -> ShNames sh -> [C.Stmt] -> [C.Stmt]
+ makeLoops ShZ ShZ body = body
+ makeLoops (ShS n ns) (ShS i is) body =
+ makeLoops ns is [C.SFor (C.TInt C.B64) i (C.ELit "0") (C.EVar n) body]
+ return (makeLoops sizenames idxnames (fun idxnames (linearIndexExpr sizenames idxnames)))
+
genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv')
genVarsAEnv (LeftHandSideWildcard _) env = return (ANIgnore, env)
genVarsAEnv (LeftHandSideSingle (ArrayR sht ty)) env = do
diff --git a/SC/Defs.hs b/SC/Defs.hs
index bb8e03f..fac4e33 100644
--- a/SC/Defs.hs
+++ b/SC/Defs.hs
@@ -4,6 +4,7 @@ module SC.Defs where
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Representation.Array
+import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Type
import qualified Language.C as C
@@ -59,6 +60,9 @@ type ANames = ITup TypedAName
type Exprs = ITup C.Expr
+itupEvars :: ITup TypedName t -> Exprs t
+itupEvars = itupmap (\(TypedName _ n) -> C.EVar n)
+
-- Type is the pointer type of the type that this name is supposed to be according to the type index.
data TypedAName = TypedAName C.Type Name
@@ -84,6 +88,20 @@ shnamesList :: ShNames sh -> [TypedName]
shnamesList ShZ = []
shnamesList (ShS n shns) = TypedName (C.TInt C.B64) n : shnamesList shns
+makeShNames :: ShapeR sh -> ITup TypedName sh -> ShNames sh
+makeShNames ShapeRz ITupIgnore = ShZ
+makeShNames (ShapeRsnoc sht) (ITupPair ns (ITupSingle (TypedName _ n))) =
+ ShS n (makeShNames sht ns)
+makeShNames _ _ = error "wat"
+
+fromShNames :: ShNames sh -> ITup TypedName sh
+fromShNames ShZ = ITupIgnore
+fromShNames (ShS n ns) = ITupPair (fromShNames ns) (ITupSingle (TypedName (C.TInt C.B64) n))
+
+shNamesShape :: ShNames sh -> ShapeR sh
+shNamesShape ShZ = ShapeRz
+shNamesShape (ShS _ ns) = ShapeRsnoc (shNamesShape ns)
+
-- GENERATING VARIABLE NAMES
-- -------------------------
diff --git a/SC/Exp.hs b/SC/Exp.hs
index d033cc8..2bd2b37 100644
--- a/SC/Exp.hs
+++ b/SC/Exp.hs
@@ -27,6 +27,8 @@ data CompiledFun aenv t1 t2 =
-- the given names.
-- The arguments will refer to array variable names found in the
-- original array environment.
+ [TypedAName]
+ -- ^ Arrays that the constructed arguments use from the environment
-- | The function must be single-argument. Uncurry if necessary (e.g. for zipWith).
compileFun :: AVarEnv aenv -> A.Fun aenv (t1 -> t2) -> SC (CompiledFun aenv t1 t2)
@@ -48,6 +50,7 @@ compileFun aenv (A.Lam lhs (A.Body body)) = do
map (\(TypedAName _ n) -> C.EVar n) usedA
++ itupList argexprs
++ map (\(TypedName _ n) -> C.EPtrTo (C.EVar n)) (itupList destnames))
+ usedA
where
genoutstores :: Names t -> Exprs t -> [C.Stmt]
genoutstores ITupIgnore _ = []