diff options
Diffstat (limited to 'SC/Afun.hs')
-rw-r--r-- | SC/Afun.hs | 141 |
1 files changed, 141 insertions, 0 deletions
diff --git a/SC/Afun.hs b/SC/Afun.hs new file mode 100644 index 0000000..3379cc6 --- /dev/null +++ b/SC/Afun.hs @@ -0,0 +1,141 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE StandaloneDeriving #-} +module SC.Afun where + +import qualified Data.Array.Accelerate.AST as A +import Data.Array.Accelerate.AST.LeftHandSide +import Data.Array.Accelerate.Representation.Array +import Data.Array.Accelerate.Representation.Type + +import qualified Language.C as C +import SC.Acc +import SC.Defs +import SC.Monad + + +-- | Variable names for a tuple of arrays. Each array is represented in +-- struct-of-arrays form. For example: the type +-- @(Scalar Double, Matrix (Int, Float))@, which is internally represented as +-- @(Array () Double, Array (((), Int), Int) (Int, Float))@, would be +-- described as follows: (the variable names will differ) +-- +-- > CATNPair (CATNArray CSNNil +-- > (CANName TDouble (Name "a"))) +-- > (CATNArray (CSNSnoc (CSNSnoc CSNNil (Name "n1")) (Name "n2")) +-- > (CANPair (CANName (TInt B64) (Name "b")) +-- > (CANName TFloat (Name "c")))) +-- +-- Suppose that the Accelerate function in question has return type +-- @Vector Double@, which is to say @Array ((), Int) Double@, with description: +-- +-- > CATNArray (CSNSnoc CSNNil (Name "m")) +-- > (CANName TDouble (Name "r")) +-- +-- Then its C function definition would look as follows: +-- +-- > void function(double *a, +-- > int64_t n1, int64_t n2, int64_t *b, int64_t *c, +-- > int64_t m, double **r); +-- +-- Note that the first input array array here has zero shape arguments because +-- it is zero-dimensional. +data CArrTupNames a where + CATNPair :: CArrTupNames a -> CArrTupNames b -> CArrTupNames (a, b) + CATNArray :: CShNames sh -> CArrNames sh a -> CArrTupNames (Array sh a) + CATNNil :: CArrTupNames () +deriving instance Show (CArrTupNames a) + +-- | Names for the shape of an array. See 'CArrTupNames' for more information. +-- +-- Note that the names in this structure are are to be interpreted as variables +-- of type @int64_t@. +data CShNames sh where + CSNSnoc :: CShNames sh -> C.Name -> CShNames (sh, Int) + CSNNil :: CShNames () +deriving instance Show (CShNames a) + +-- | Names for a single array. See 'CArrTupNames' for more information. +-- +-- Note that the 'C.Type' in 'CANName' is the /element/ type of the array. +data CArrNames sh a where + CANPair :: CArrNames sh a -> CArrNames sh b -> CArrNames sh (a, b) + CANName :: C.Type -> C.Name -> CArrNames sh a + CANNil :: CArrNames sh () +deriving instance Show (CArrNames sh a) + +-- | The function passed should have exactly one argument (that may consist of +-- multiple arrays in a tuple, of course). +-- +-- The result consists of: +-- 1. An array of auxiliary function definitions that the program needs, on top +-- of the prelude. +-- 2. The function that implements the top-level Accelerate array function. +-- 3. The variable names corresponding to the components of the argument. +-- 4. The variable names corresponding to the components of the result. These +-- are double-pointer arguments to the function in (2.). +-- +-- For an example, see the documentation of 'CArrTupNames'. +compileAfun1 :: C.Name + -> A.Afun (a -> b) + -> SC ([C.FunDef], C.FunDef, CArrTupNames a, CArrTupNames b) +compileAfun1 procname (A.Alam lhs (A.Abody acc)) = do + (argnames, aenv) <- genVarsAEnv lhs AVENil + destnames <- genAVarsTup (A.arraysR acc) + let destShapeDeclSts = [C.SDecl t n Nothing + | TypedName t n <- fst (tupanamesList destnames)] + outnames <- genAVarsTup (A.arraysR acc) + (auxdefs, stmts) <- compileCommands <$> compileAcc' aenv destnames acc + return (auxdefs + ,C.ProcDef procname + (map (\case Left (TypedName t n) -> (t, n) + Right (TypedAName t n) -> (t, n)) + (tupanamesList' argnames) + ++ + map (\case Left (TypedName t n) -> (C.TPtr t, n) + Right (TypedAName t n) -> (C.TPtr t, n)) + (tupanamesList' outnames)) + (destShapeDeclSts ++ + stmts ++ + [C.SStore outn (C.ELit "0") (C.EVar destn) + | (outn, destn) <- zipOutSrcNamesT outnames destnames]) + ,makeCArrTupNames (lhsToTupR lhs) (\(C.TPtr t) -> t) argnames + ,makeCArrTupNames (A.arraysR acc) (\(C.TPtr t) -> t) outnames) + where + makeCArrTupNames :: ArraysR a -> (C.Type -> C.Type) -> TupANames a -> CArrTupNames a + makeCArrTupNames (TupRpair t1 t2) typefun (ANPair an1 an2) = + CATNPair (makeCArrTupNames t1 typefun an1) (makeCArrTupNames t2 typefun an2) + makeCArrTupNames (TupRsingle (ArrayR _ t)) typefun (ANArray shn ans) = + CATNArray (makeCShNames shn) (makeCArrNames t typefun ans) + makeCArrTupNames TupRunit _ ANIgnore = CATNNil + makeCArrTupNames _ _ ANIgnore = error "Ignore of non-nil element in generated names" + + makeCShNames :: ShNames sh -> CShNames sh + makeCShNames ShZ = CSNNil + makeCShNames (ShS ns n) = CSNSnoc (makeCShNames ns) n + + makeCArrNames :: TypeR a -> (C.Type -> C.Type) -> ANames a -> CArrNames sh a + makeCArrNames (TupRpair t1 t2) typefun (ITupPair an1 an2) = + CANPair (makeCArrNames t1 typefun an1) (makeCArrNames t2 typefun an2) + makeCArrNames (TupRsingle _) typefun (ITupSingle (TypedAName ty n)) = + CANName (typefun ty) n + makeCArrNames TupRunit _ ITupIgnore = CANNil + makeCArrNames _ _ ITupIgnore = error "Ignore of non-nil element in generated names" + makeCArrNames _ _ _ = error "Invalid GADTs" + + zipOutSrcNamesT :: TupANames t -> TupANames t -> [(C.Name, C.Name)] + zipOutSrcNamesT ANIgnore _ = [] + zipOutSrcNamesT _ ANIgnore = error "Ignore in source names but not in out names" + zipOutSrcNamesT (ANArray shn ns) (ANArray shn' ns') = + zipWith (\(TypedName _ n) (TypedName _ n') -> (n, n')) + (shnamesList shn) (shnamesList shn') + ++ zipOutSrcNames ns ns' + zipOutSrcNamesT (ANPair a b) (ANPair a' b') = zipOutSrcNamesT a a' ++ zipOutSrcNamesT b b' + + zipOutSrcNames :: ANames t -> ANames t -> [(C.Name, C.Name)] + zipOutSrcNames ITupIgnore _ = [] + zipOutSrcNames _ ITupIgnore = error "Ignore in source names but not in out names" + zipOutSrcNames (ITupPair a b) (ITupPair a' b') = zipOutSrcNames a a' ++ zipOutSrcNames b b' + zipOutSrcNames (ITupSingle (TypedAName _ n)) (ITupSingle (TypedAName _ n')) = [(n, n')] + zipOutSrcNames _ _ = error "Invalid GADTs" +compileAfun1 _ _ = throw "Not an array function with exactly one argument" |