{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE StandaloneDeriving #-}
module SC.Afun where

import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Type
import qualified Language.C as C
import SC.Acc
import SC.Defs
import SC.Monad


-- | Variable names for a tuple of arrays. Each array is represented in
-- struct-of-arrays form. For example: the type
-- @(Scalar Double, Matrix (Int, Float))@, which is internally represented as
-- @(Array () Double, Array (((), Int), Int) (Int, Float))@, would be
-- described as follows: (the variable names will differ)
--
-- > CATNPair (CATNArray CSNNil
-- >                     (CANName TDouble (Name "a")))
-- >          (CATNArray (CSNSnoc (CSNSnoc CSNNil (Name "n1")) (Name "n2"))
-- >                     (CANPair (CANName (TInt B64) (Name "b"))
-- >                              (CANName TFloat (Name "c"))))
--
-- Suppose that the Accelerate function in question has return type
-- @Vector Double@, which is to say @Array ((), Int) Double@, with description:
--
-- > CATNArray (CSNSnoc CSNNil (Name "m"))
-- >           (CANName TDouble (Name "r"))
--
-- Then its C function prototype would look as follows:
--
-- > void function(double *a,
-- >               int64_t n1, int64_t n2, int64_t *b, float *c,
-- >               int64_t *m, double **r);
--
-- Note that the first input array here has zero shape arguments because
-- it is zero-dimensional. The components of the result get one extra level
-- of indirection (@int64_t *m@, @double **r@) because the generated function
-- stores the result through them.
data CArrTupNames a where
    CATNPair :: CArrTupNames a -> CArrTupNames b -> CArrTupNames (a, b)
    CATNArray :: CShNames sh -> CArrNames sh a -> CArrTupNames (Array sh a)
    CATNNil :: CArrTupNames ()

deriving instance Show (CArrTupNames a)

-- | Names for the shape of an array. See 'CArrTupNames' for more information.
--
-- Note that the names in this structure are to be interpreted as variables
-- of type @int64_t@. (For the result of 'compileAfun1', the corresponding C
-- parameters are pointers to @int64_t@.)
data CShNames sh where
    CSNSnoc :: CShNames sh -> C.Name -> CShNames (sh, Int)
    CSNNil :: CShNames ()

deriving instance Show (CShNames a)

-- | Names for a single array. See 'CArrTupNames' for more information.
--
-- Note that the 'C.Type' in 'CANName' is the /element/ type of the array.
data CArrNames sh a where
    CANPair :: CArrNames sh a -> CArrNames sh b -> CArrNames sh (a, b)
    CANName :: C.Type -> C.Name -> CArrNames sh a
    CANNil :: CArrNames sh ()

deriving instance Show (CArrNames sh a)

-- | Compile a one-argument Accelerate array function to a C procedure.
--
-- The function passed should have exactly one argument (that may consist of
-- multiple arrays in a tuple, of course); otherwise this fails via 'throw'.
--
-- The result consists of:
--
-- 1. A list of auxiliary function definitions that the program needs, on top
--    of the prelude.
-- 2. The function that implements the top-level Accelerate array function.
-- 3. The variable names corresponding to the components of the argument.
-- 4. The variable names corresponding to the components of the result. In the
--    function in (2.) these are output parameters with one extra level of
--    indirection: shape components are pointers to @int64_t@, and array
--    components are double pointers to the element type.
--
-- For an example, see the documentation of 'CArrTupNames'.
compileAfun1 :: C.Name -> A.Afun (a -> b) -> SC ([C.FunDef], C.FunDef, CArrTupNames a, CArrTupNames b)
compileAfun1 procname (A.Alam lhs (A.Abody acc)) = do
    (argnames, aenv) <- genVarsAEnv lhs AVENil
    -- Local destination variables handed to 'compileAcc''; at the end of the
    -- procedure their values are copied into the output parameters.
    destnames <- genAVarsTup (A.arraysR acc)
    -- Only the destination shape variables are declared here.
    let destShapeDeclSts = [C.SDecl t n Nothing | TypedName t n <- fst (tupanamesList destnames)]
    -- Names of the output parameters of the generated procedure.
    outnames <- genAVarsTup (A.arraysR acc)
    (auxdefs, stmts) <- compileCommands <$> compileAcc' aenv destnames acc
    let -- Argument components keep their own types.
        argParams = map (cParam id) (tupanamesList' argnames)
        -- Result components get one extra pointer level so the procedure can
        -- store through them.
        outParams = map (cParam C.TPtr) (tupanamesList' outnames)
        -- One store per result component, from the destination variable into
        -- the corresponding output parameter (presumably @out[0] = dest;@).
        storeSts = [C.SStore outn (C.ELit "0") (C.EVar destn)
                   | (outn, destn) <- zipOutSrcNamesT outnames destnames]
    return ( auxdefs
           , C.ProcDef procname (argParams ++ outParams)
                       (destShapeDeclSts ++ stmts ++ storeSts)
           , makeCArrTupNames (lhsToTupR lhs) stripPtr argnames
           , makeCArrTupNames (A.arraysR acc) stripPtr outnames )
  where
    -- (type, name) pair for one prototype parameter; 'wrap' is applied to the
    -- type of both shape ('Left') and array ('Right') components.
    cParam wrap (Left (TypedName t n)) = (wrap t, n)
    cParam wrap (Right (TypedAName t n)) = (wrap t, n)

    -- Generated array variables have pointer types; the names reported to the
    -- caller carry the element type instead. Total replacement for the former
    -- partial lambda, with a diagnostic if the invariant is ever violated.
    stripPtr :: C.Type -> C.Type
    stripPtr (C.TPtr t) = t
    stripPtr t = error ("compileAfun1: expected a pointer type for an array variable, got " ++ show t)

    makeCArrTupNames :: ArraysR a -> (C.Type -> C.Type) -> TupANames a -> CArrTupNames a
    makeCArrTupNames (TupRpair t1 t2) typefun (ANPair an1 an2) =
        CATNPair (makeCArrTupNames t1 typefun an1) (makeCArrTupNames t2 typefun an2)
    makeCArrTupNames (TupRsingle (ArrayR _ t)) typefun (ANArray shn ans) =
        CATNArray (makeCShNames shn) (makeCArrNames t typefun ans)
    makeCArrTupNames TupRunit _ ANIgnore = CATNNil
    makeCArrTupNames _ _ ANIgnore = error "Ignore of non-nil element in generated names"

    makeCShNames :: ShNames sh -> CShNames sh
    makeCShNames ShZ = CSNNil
    makeCShNames (ShS ns n) = CSNSnoc (makeCShNames ns) n

    makeCArrNames :: TypeR a -> (C.Type -> C.Type) -> ANames a -> CArrNames sh a
    makeCArrNames (TupRpair t1 t2) typefun (ITupPair an1 an2) =
        CANPair (makeCArrNames t1 typefun an1) (makeCArrNames t2 typefun an2)
    makeCArrNames (TupRsingle _) typefun (ITupSingle (TypedAName ty n)) = CANName (typefun ty) n
    makeCArrNames TupRunit _ ITupIgnore = CANNil
    makeCArrNames _ _ ITupIgnore = error "Ignore of non-nil element in generated names"
    makeCArrNames _ _ _ = error "Invalid GADTs"

    -- Pair each output-parameter name with the destination name it receives.
    zipOutSrcNamesT :: TupANames t -> TupANames t -> [(C.Name, C.Name)]
    zipOutSrcNamesT ANIgnore _ = []
    zipOutSrcNamesT _ ANIgnore = error "Ignore in source names but not in out names"
    zipOutSrcNamesT (ANArray shn ns) (ANArray shn' ns') =
        zipWith (\(TypedName _ n) (TypedName _ n') -> (n, n')) (shnamesList shn) (shnamesList shn')
        ++ zipOutSrcNames ns ns'
    zipOutSrcNamesT (ANPair a b) (ANPair a' b') = zipOutSrcNamesT a a' ++ zipOutSrcNamesT b b'

    zipOutSrcNames :: ANames t -> ANames t -> [(C.Name, C.Name)]
    zipOutSrcNames ITupIgnore _ = []
    zipOutSrcNames _ ITupIgnore = error "Ignore in source names but not in out names"
    zipOutSrcNames (ITupPair a b) (ITupPair a' b') = zipOutSrcNames a a' ++ zipOutSrcNames b b'
    zipOutSrcNames (ITupSingle (TypedAName _ n)) (ITupSingle (TypedAName _ n')) = [(n, n')]
    zipOutSrcNames _ _ = error "Invalid GADTs"
compileAfun1 _ _ = throw "Not an array function with exactly one argument"