summaryrefslogtreecommitdiff
path: root/SC/Afun.hs
diff options
context:
space:
mode:
Diffstat (limited to 'SC/Afun.hs')
-rw-r--r--SC/Afun.hs141
1 files changed, 141 insertions, 0 deletions
diff --git a/SC/Afun.hs b/SC/Afun.hs
new file mode 100644
index 0000000..3379cc6
--- /dev/null
+++ b/SC/Afun.hs
@@ -0,0 +1,141 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE StandaloneDeriving #-}
+module SC.Afun where
+
+import qualified Data.Array.Accelerate.AST as A
+import Data.Array.Accelerate.AST.LeftHandSide
+import Data.Array.Accelerate.Representation.Array
+import Data.Array.Accelerate.Representation.Type
+
+import qualified Language.C as C
+import SC.Acc
+import SC.Defs
+import SC.Monad
+
+
+-- | Variable names for a tuple of arrays. Each array is represented in
+-- struct-of-arrays form. For example: the type
+-- @(Scalar Double, Matrix (Int, Float))@, which is internally represented as
+-- @(Array () Double, Array (((), Int), Int) (Int, Float))@, would be
+-- described as follows: (the variable names will differ)
+--
+-- > CATNPair (CATNArray CSNNil
+-- > (CANName TDouble (Name "a")))
+-- > (CATNArray (CSNSnoc (CSNSnoc CSNNil (Name "n1")) (Name "n2"))
+-- > (CANPair (CANName (TInt B64) (Name "b"))
+-- > (CANName TFloat (Name "c"))))
+--
+-- Suppose that the Accelerate function in question has return type
+-- @Vector Double@, which is to say @Array ((), Int) Double@, with description:
+--
+-- > CATNArray (CSNSnoc CSNNil (Name "m"))
+-- > (CANName TDouble (Name "r"))
+--
+-- Then its C function definition would look as follows:
+--
+-- > void function(double *a,
+-- > int64_t n1, int64_t n2, int64_t *b, int64_t *c,
+-- > int64_t m, double **r);
+--
+-- Note that the first input array array here has zero shape arguments because
+-- it is zero-dimensional.
+data CArrTupNames a where
+ CATNPair :: CArrTupNames a -> CArrTupNames b -> CArrTupNames (a, b)
+ CATNArray :: CShNames sh -> CArrNames sh a -> CArrTupNames (Array sh a)
+ CATNNil :: CArrTupNames ()
+deriving instance Show (CArrTupNames a)
+
+-- | Names for the shape of an array. See 'CArrTupNames' for more information.
+--
+-- Note that the names in this structure are are to be interpreted as variables
+-- of type @int64_t@.
+data CShNames sh where
+ CSNSnoc :: CShNames sh -> C.Name -> CShNames (sh, Int)
+ CSNNil :: CShNames ()
+deriving instance Show (CShNames a)
+
+-- | Names for a single array. See 'CArrTupNames' for more information.
+--
+-- Note that the 'C.Type' in 'CANName' is the /element/ type of the array.
+data CArrNames sh a where
+ CANPair :: CArrNames sh a -> CArrNames sh b -> CArrNames sh (a, b)
+ CANName :: C.Type -> C.Name -> CArrNames sh a
+ CANNil :: CArrNames sh ()
+deriving instance Show (CArrNames sh a)
+
+-- | The function passed should have exactly one argument (that may consist of
+-- multiple arrays in a tuple, of course).
+--
+-- The result consists of:
+-- 1. An array of auxiliary function definitions that the program needs, on top
+-- of the prelude.
+-- 2. The function that implements the top-level Accelerate array function.
+-- 3. The variable names corresponding to the components of the argument.
+-- 4. The variable names corresponding to the components of the result. These
+-- are double-pointer arguments to the function in (2.).
+--
+-- For an example, see the documentation of 'CArrTupNames'.
+compileAfun1 :: C.Name
+ -> A.Afun (a -> b)
+ -> SC ([C.FunDef], C.FunDef, CArrTupNames a, CArrTupNames b)
+compileAfun1 procname (A.Alam lhs (A.Abody acc)) = do
+ (argnames, aenv) <- genVarsAEnv lhs AVENil
+ destnames <- genAVarsTup (A.arraysR acc)
+ let destShapeDeclSts = [C.SDecl t n Nothing
+ | TypedName t n <- fst (tupanamesList destnames)]
+ outnames <- genAVarsTup (A.arraysR acc)
+ (auxdefs, stmts) <- compileCommands <$> compileAcc' aenv destnames acc
+ return (auxdefs
+ ,C.ProcDef procname
+ (map (\case Left (TypedName t n) -> (t, n)
+ Right (TypedAName t n) -> (t, n))
+ (tupanamesList' argnames)
+ ++
+ map (\case Left (TypedName t n) -> (C.TPtr t, n)
+ Right (TypedAName t n) -> (C.TPtr t, n))
+ (tupanamesList' outnames))
+ (destShapeDeclSts ++
+ stmts ++
+ [C.SStore outn (C.ELit "0") (C.EVar destn)
+ | (outn, destn) <- zipOutSrcNamesT outnames destnames])
+ ,makeCArrTupNames (lhsToTupR lhs) (\(C.TPtr t) -> t) argnames
+ ,makeCArrTupNames (A.arraysR acc) (\(C.TPtr t) -> t) outnames)
+ where
+ makeCArrTupNames :: ArraysR a -> (C.Type -> C.Type) -> TupANames a -> CArrTupNames a
+ makeCArrTupNames (TupRpair t1 t2) typefun (ANPair an1 an2) =
+ CATNPair (makeCArrTupNames t1 typefun an1) (makeCArrTupNames t2 typefun an2)
+ makeCArrTupNames (TupRsingle (ArrayR _ t)) typefun (ANArray shn ans) =
+ CATNArray (makeCShNames shn) (makeCArrNames t typefun ans)
+ makeCArrTupNames TupRunit _ ANIgnore = CATNNil
+ makeCArrTupNames _ _ ANIgnore = error "Ignore of non-nil element in generated names"
+
+ makeCShNames :: ShNames sh -> CShNames sh
+ makeCShNames ShZ = CSNNil
+ makeCShNames (ShS ns n) = CSNSnoc (makeCShNames ns) n
+
+ makeCArrNames :: TypeR a -> (C.Type -> C.Type) -> ANames a -> CArrNames sh a
+ makeCArrNames (TupRpair t1 t2) typefun (ITupPair an1 an2) =
+ CANPair (makeCArrNames t1 typefun an1) (makeCArrNames t2 typefun an2)
+ makeCArrNames (TupRsingle _) typefun (ITupSingle (TypedAName ty n)) =
+ CANName (typefun ty) n
+ makeCArrNames TupRunit _ ITupIgnore = CANNil
+ makeCArrNames _ _ ITupIgnore = error "Ignore of non-nil element in generated names"
+ makeCArrNames _ _ _ = error "Invalid GADTs"
+
+ zipOutSrcNamesT :: TupANames t -> TupANames t -> [(C.Name, C.Name)]
+ zipOutSrcNamesT ANIgnore _ = []
+ zipOutSrcNamesT _ ANIgnore = error "Ignore in source names but not in out names"
+ zipOutSrcNamesT (ANArray shn ns) (ANArray shn' ns') =
+ zipWith (\(TypedName _ n) (TypedName _ n') -> (n, n'))
+ (shnamesList shn) (shnamesList shn')
+ ++ zipOutSrcNames ns ns'
+ zipOutSrcNamesT (ANPair a b) (ANPair a' b') = zipOutSrcNamesT a a' ++ zipOutSrcNamesT b b'
+
+ zipOutSrcNames :: ANames t -> ANames t -> [(C.Name, C.Name)]
+ zipOutSrcNames ITupIgnore _ = []
+ zipOutSrcNames _ ITupIgnore = error "Ignore in source names but not in out names"
+ zipOutSrcNames (ITupPair a b) (ITupPair a' b') = zipOutSrcNames a a' ++ zipOutSrcNames b b'
+ zipOutSrcNames (ITupSingle (TypedAName _ n)) (ITupSingle (TypedAName _ n')) = [(n, n')]
+ zipOutSrcNames _ _ = error "Invalid GADTs"
+compileAfun1 _ _ = throw "Not an array function with exactly one argument"